Skip to content

Commit

Permalink
change AESGCMstate to AESState for expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
ccfelius committed Nov 5, 2024
1 parent 7fdf254 commit 46644e3
Show file tree
Hide file tree
Showing 9 changed files with 275 additions and 66 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ set(LOADABLE_EXTENSION_NAME ${TARGET_NAME}_loadable_extension)
project(${TARGET_NAME})
include_directories(src/include)
add_subdirectory(src)
include_directories(../duckdb/third_party/httplib/include)

# for now, do this manually
set(EXTENSION_SOURCES src/simple_encryption_extension.cpp
Expand All @@ -26,6 +27,7 @@ set(EXTENSION_SOURCES src/simple_encryption_extension.cpp
src/core/functions/function_data/encrypt_function_data.cpp
src/core/functions/table/encrypt_table.cpp
src/core/utils/simple_encryption_utils.cpp
src/core/crypto/crypto_primitives.cpp
)

build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES})
Expand Down
4 changes: 0 additions & 4 deletions extension_config.cmake
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
# This file is included by DuckDB's build system. It specifies which extension to load

# Extension from this repo
duckdb_extension_load(simple_encryption
LOAD_TEST
SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}
)

# Load httpfs extension
duckdb_extension_load(httpfs)

# Any extra extensions that should be built
# e.g.: duckdb_extension_load(json)
5 changes: 5 additions & 0 deletions src/core/crypto/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Append this directory's crypto_primitives.cpp to the EXTENSION_SOURCES list
# and publish the extended list to the parent scope.
set(EXTENSION_SOURCES
${EXTENSION_SOURCES}
${CMAKE_CURRENT_SOURCE_DIR}/crypto_primitives.cpp
PARENT_SCOPE
)
151 changes: 151 additions & 0 deletions src/core/crypto/crypto_primitives.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#include "simple_encryption/core/crypto/crypto_primitives.hpp"
#include "mbedtls_wrapper.hpp"
#include <iostream>
#include "duckdb/common/common.hpp"
#include <stdio.h>

// OpenSSL functions
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/rand.h>

namespace duckdb {

// Hashes `in_len` bytes starting at `in` with the mbedTLS wrapper's
// ComputeSha256Hash and stores the result in `out`.
void sha256(const char *in, size_t in_len, hash_bytes &out) {
	// The wrapper takes a plain char buffer, so view the digest array as one.
	auto *digest = reinterpret_cast<char *>(out);
	duckdb_mbedtls::MbedTlsWrapper::ComputeSha256Hash(in, in_len, digest);
}

// Computes an HMAC over `message` keyed with `secret_len` bytes at `secret`
// (delegating to the mbedTLS wrapper's Hmac256) and stores the result in `out`.
void hmac256(const std::string &message, const char *secret, size_t secret_len, hash_bytes &out) {
	const char *payload = message.data();
	const size_t payload_len = message.size();
	auto *mac = reinterpret_cast<char *>(out);
	duckdb_mbedtls::MbedTlsWrapper::Hmac256(secret, secret_len, payload, payload_len, mac);
}

void hmac256(std::string message, hash_bytes secret, hash_bytes &out) {
hmac256(message, (char *)secret, sizeof(hash_bytes), out);
}

// Writes the lowercase hexadecimal form of every byte of `in` into `out`,
// two characters per input byte. No terminator is appended.
void hex256(hash_bytes &in, hash_str &out) {
	static const char digits[] = "0123456789abcdef";
	const unsigned char *src = in;
	unsigned char *dst = out;
	for (size_t i = 0; i < sizeof(in); i++) {
		const unsigned char byte = src[i];
		// High nibble first, then low nibble.
		dst[2 * i] = digits[(byte >> 4) & 0xF];
		dst[2 * i + 1] = digits[byte & 0xF];
	}
}

// Selects the OpenSSL AES-GCM cipher matching the byte length of `key`
// (16 -> AES-128, 24 -> AES-192, 32 -> AES-256).
// Throws InternalException for any other key length.
const EVP_CIPHER *GetCipher(const string &key) {
	// For now, we only support GCM ciphers
	const auto key_bytes = key.size();
	if (key_bytes == 16) {
		return EVP_aes_128_gcm();
	}
	if (key_bytes == 24) {
		return EVP_aes_192_gcm();
	}
	if (key_bytes == 32) {
		return EVP_aes_256_gcm();
	}
	throw InternalException("Invalid AES key length");
}

// Allocates the OpenSSL cipher context used by this state object.
// Throws InternalException when OpenSSL returns no context.
AESStateSSL::AESStateSSL() : gcm_context(EVP_CIPHER_CTX_new()) {
	if (gcm_context == nullptr) {
		throw InternalException("AES GCM failed with initializing context");
	}
}

// Releases the OpenSSL cipher context that the constructor allocated.
AESStateSSL::~AESStateSSL() {
// Clean up
EVP_CIPHER_CTX_free(gcm_context);
}

// Returns the value of the `ssl` member.
// NOTE(review): presumably true for this OpenSSL-backed state; the member is set outside this file, so confirm in the header.
bool AESStateSSL::IsOpenSSL() {
return ssl;
}

// Fills `len` bytes at `data` with output from OpenSSL's RAND_bytes
// (used for nonce generation).
//
// Throws InternalException when RAND_bytes does not report success: the
// buffer contents cannot be trusted in that case, and continuing with a
// predictable nonce would undermine AES-GCM.
void AESStateSSL::GenerateRandomData(data_ptr_t data, idx_t len) {
	// RAND_bytes takes the byte count as an int and returns 1 only on success.
	if (1 != RAND_bytes(data, static_cast<int>(len))) {
		throw InternalException("Generating random data with OpenSSL RAND_bytes failed");
	}
}

// Puts the state into ENCRYPT mode and initializes the OpenSSL context with
// the cipher selected by GetCipher(*key), the key bytes and `iv`.
// `iv_len` is accepted for interface compatibility but is not consulted here.
// Throws InternalException when EVP_EncryptInit_ex does not return 1.
void AESStateSSL::InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, const string *key) {
	mode = ENCRYPT;

	const EVP_CIPHER *cipher = GetCipher(*key);
	const auto *key_bytes = const_data_ptr_cast(key->data());
	const int status = EVP_EncryptInit_ex(gcm_context, cipher, nullptr, key_bytes, iv);
	if (status != 1) {
		throw InternalException("EncryptInit failed");
	}
}

// Puts the state into DECRYPT mode and initializes the OpenSSL context with
// the cipher selected by GetCipher(*key), the key bytes and `iv`.
// `iv_len` is accepted for interface compatibility but is not consulted here.
// Throws InternalException when EVP_DecryptInit_ex does not return 1.
void AESStateSSL::InitializeDecryption(const_data_ptr_t iv, idx_t iv_len, const string *key) {
	mode = DECRYPT;

	const EVP_CIPHER *cipher = GetCipher(*key);
	const auto *key_bytes = const_data_ptr_cast(key->data());
	const int status = EVP_DecryptInit_ex(gcm_context, cipher, nullptr, key_bytes, iv);
	if (status != 1) {
		throw InternalException("DecryptInit failed");
	}
}

// Encrypts or decrypts (depending on the mode chosen by the preceding
// Initialize* call) `in_len` bytes from `in` into `out`.
//
// Returns the number of bytes OpenSSL reported as written.
// Throws InternalException when the OpenSSL update call fails or when the
// number of bytes written differs from `in_len`.
//
// The incoming `out_len` value is not used as a bound; the parameter is only
// overwritten with the produced byte count before the length check.
size_t AESStateSSL::Process(const_data_ptr_t in, idx_t in_len, data_ptr_t out, idx_t out_len) {
	// OpenSSL reports the produced byte count through an `int *`. Use a real
	// int instead of reinterpreting the address of the wider idx_t: that cast
	// only updated part of the variable and depended on byte order.
	int written = 0;

	switch (mode) {
	case ENCRYPT:
		if (1 != EVP_EncryptUpdate(gcm_context, data_ptr_cast(out), &written, const_data_ptr_cast(in),
		                           static_cast<int>(in_len))) {
			throw InternalException("Encryption failed at OpenSSL EVP_EncryptUpdate");
		}
		break;

	case DECRYPT:
		if (1 != EVP_DecryptUpdate(gcm_context, data_ptr_cast(out), &written, const_data_ptr_cast(in),
		                           static_cast<int>(in_len))) {
			throw InternalException("Decryption failed at OpenSSL EVP_DecryptUpdate");
		}
		break;
	}

	out_len = static_cast<idx_t>(written);
	if (out_len != in_len) {
		throw InternalException("AES GCM failed, in- and output lengths differ");
	}

	return out_len;
}

// Completes the current AES-GCM operation.
//
// `out` / `out_len`: output buffer and the number of bytes already written to
// it; any final bytes are placed at `out + out_len`.
// `tag` / `tag_len`: in ENCRYPT mode the computed tag is written to `tag`; in
// DECRYPT mode `tag` is the expected tag that OpenSSL checks against.
//
// Returns `out_len` plus the number of bytes produced by the final call.
// Throws InternalException on OpenSSL failures or when no valid mode is set,
// and InvalidInputException when the decryption tag check fails.
size_t AESStateSSL::Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_t tag_len) {
	// OpenSSL reports the produced byte count through an `int *`. Use a real
	// int instead of reinterpreting the address of the wider idx_t.
	int final_len = 0;
	auto final_out = data_ptr_cast(out) + out_len;

	switch (mode) {
	case ENCRYPT: {
		if (1 != EVP_EncryptFinal_ex(gcm_context, final_out, &final_len)) {
			throw InternalException("EncryptFinal failed");
		}
		// The computed tag is written at the end of a chunk
		if (1 != EVP_CIPHER_CTX_ctrl(gcm_context, EVP_CTRL_GCM_GET_TAG, static_cast<int>(tag_len), tag)) {
			throw InternalException("Calculating the tag failed");
		}
		return out_len + static_cast<idx_t>(final_len);
	}
	case DECRYPT: {
		// Set expected tag value
		if (!EVP_CIPHER_CTX_ctrl(gcm_context, EVP_CTRL_GCM_SET_TAG, static_cast<int>(tag_len), tag)) {
			throw InternalException("Finalizing tag failed");
		}
		// EVP_DecryptFinal_ex returns a non-positive value when the tag does not verify.
		const int ret = EVP_DecryptFinal_ex(gcm_context, final_out, &final_len);
		if (ret > 0) {
			// success
			return out_len + static_cast<idx_t>(final_len);
		}
		throw InvalidInputException("Computed AES tag differs from read AES tag, are you using the right key?");
	}
	}

	// Reached only when `mode` holds neither ENCRYPT nor DECRYPT; without this
	// the function could fall off the end of a non-void function.
	throw InternalException("AES GCM Finalize called without a valid mode");
}

Check warning on line 141 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64, x64-windows-static-md)

'duckdb::AESStateSSL::Finalize': not all control paths return a value [D:\a\simple-encryption\simple-encryption\build\release\extension\simple_encryption\simple_encryption_extension.vcxproj]

Check warning on line 141 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64, x64-windows-static-md)

'duckdb::AESStateSSL::Finalize': not all control paths return a value [D:\a\simple-encryption\simple-encryption\build\release\extension\simple_encryption\simple_encryption_loadable_extension.vcxproj]

Check warning on line 141 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64, x64-windows-static-md)

'duckdb::AESStateSSL::Finalize': not all control paths return a value [D:\a\simple-encryption\simple-encryption\build\release\extension\simple_encryption\simple_encryption_extension.vcxproj]

Check warning on line 141 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64, x64-windows-static-md)

'duckdb::AESStateSSL::Finalize': not all control paths return a value [D:\a\simple-encryption\simple-encryption\build\release\extension\simple_encryption\simple_encryption_loadable_extension.vcxproj]

Check warning on line 141 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64_rtools, x64-mingw-static)

control reaches end of non-void function [-Wreturn-type]

Check warning on line 141 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64_rtools, x64-mingw-static)

control reaches end of non-void function [-Wreturn-type]

} // namespace duckdb

extern "C" {

// C-linkage entry point that hands out a newly allocated AESStateSSLFactory.
// The object is created with `new` and returned as a raw pointer.
DUCKDB_EXTENSION_API AESStateSSLFactory *CreateSSLFactory() {
	auto *factory = new AESStateSSLFactory();
	return factory;
}
}
2 changes: 1 addition & 1 deletion src/core/functions/function_data/encrypt_function_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ unique_ptr<FunctionData> EncryptFunctionData::Copy() const {
bool EncryptFunctionData::Equals(const FunctionData &other_p) const {
auto &other = (const EncryptFunctionData &)other_p;
// fix this to return the right id
return false;
return true;
}

unique_ptr<FunctionData> EncryptFunctionData::EncryptBind(ClientContext &context, ScalarFunction &bound_function,
Expand Down
111 changes: 51 additions & 60 deletions src/core/functions/scalar/encrypt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,65 +24,69 @@ namespace simple_encryption {

namespace core {

shared_ptr<EncryptionState> InitializeCryptoState() {
shared_ptr<EncryptionUtil> GetEncryptionUtil(ExpressionState &state) {
auto &func_expr = (BoundFunctionExpression &)state.expr;
auto &info = (EncryptFunctionData &)*func_expr.bind_info;
// get Database config
auto &config = DBConfig::GetConfig(*info.context.db);
return config.encryption_util;
}

// auto &info = (CSRFunctionData &)*func_expr.bind_info;
// auto simple_encryption_state = info.context.registered_state->Get<SimpleEncryptionState>("simple_encryption");
//
// if (!simple_encryption_state) {
// throw MissingExtensionException(
// "The simple_encryption extension has not been loaded");
// }

// for now just hardcode MBEDTLS here
shared_ptr<EncryptionState> encryption_state =
duckdb_mbedtls::MbedTlsWrapper::AESGCMStateMBEDTLSFactory()
.CreateEncryptionState();
return encryption_state;
shared_ptr<EncryptionState> InitializeCryptoState(ExpressionState &state) {
auto encryption_state = GetEncryptionUtil(state);

if (!encryption_state) {
return duckdb_mbedtls::MbedTlsWrapper::AESGCMStateMBEDTLSFactory()
.CreateEncryptionState();
}

return encryption_state->CreateEncryptionState();
}

shared_ptr<EncryptionState> InitializeEncryption() {
shared_ptr<EncryptionState> InitializeEncryption(ExpressionState &state) {

// For now, hardcode everything
// for some reason, this is 12
const string key = TEST_KEY;
unsigned char iv[16];
unsigned char iv[12];
memcpy((void *)iv, "12345678901", 12);
//
// // TODO; construct nonce based on immutable ROW_ID + hash(col_name)
// iv[12] = 0x00;
// iv[13] = 0x00;
// iv[14] = 0x00;
// iv[15] = 0x00;

// TODO; construct nonce based on immutable ROW_ID + hash(col_name)
iv[12] = 0x00;
iv[13] = 0x00;
iv[14] = 0x00;
iv[15] = 0x00;

auto encryption_state = InitializeCryptoState();
encryption_state->InitializeEncryption(iv, 16, &key);
auto encryption_state = InitializeCryptoState(state);
// encryption_state->GenerateRandomData(iv, 12);
encryption_state->InitializeEncryption(iv, 12, &key);

return encryption_state;
}

shared_ptr<EncryptionState> InitializeDecryption() {
shared_ptr<EncryptionState> InitializeDecryption(ExpressionState &state) {

// For now, hardcode everything
const string key = TEST_KEY;
unsigned char iv[16];
unsigned char iv[12];
memcpy((void *)iv, "12345678901", 12);
//
// // TODO; construct nonce based on immutable ROW_ID + hash(col_name)
iv[12] = 0x00;
iv[13] = 0x00;
iv[14] = 0x00;
iv[15] = 0x00;
// iv[12] = 0x00;
// iv[13] = 0x00;
// iv[14] = 0x00;
// iv[15] = 0x00;

auto decryption_state = InitializeCryptoState();
auto decryption_state = InitializeCryptoState(state);
decryption_state->InitializeDecryption(iv, 16, &key);

return decryption_state;
}

inline const uint8_t *DecryptValue(uint8_t *buffer, size_t size) {
inline const uint8_t *DecryptValue(uint8_t *buffer, size_t size, ExpressionState &state) {

// Initialize Encryption
auto encryption_state = InitializeDecryption();
auto encryption_state = InitializeDecryption(state);
uint8_t decryption_buffer[MAX_BUFFER_SIZE];
uint8_t *temp_buf = decryption_buffer;

Expand All @@ -92,7 +96,7 @@ inline const uint8_t *DecryptValue(uint8_t *buffer, size_t size) {
}

bool CheckEncryption(string_t printable_encrypted_data, uint8_t *buffer,
size_t size, const uint8_t *value){
size_t size, const uint8_t *value, ExpressionState &state){

// cast encrypted data to blob back and forth
// to check whether data will be lost with casting
Expand All @@ -105,7 +109,7 @@ bool CheckEncryption(string_t printable_encrypted_data, uint8_t *buffer,
"Original Encrypted Data differs from Unblobbed Encrypted Data");
}

auto decrypted_data = DecryptValue(buffer, size);
auto decrypted_data = DecryptValue(buffer, size, state);
if (memcmp(decrypted_data, value, size) != 0) {
throw InvalidInputException(
"Original Data differs from Decrypted Data");
Expand All @@ -114,37 +118,23 @@ bool CheckEncryption(string_t printable_encrypted_data, uint8_t *buffer,
return true;
}

shared_ptr<EncryptionUtil> GetEncryptionUtil(ExpressionState &state) {
auto &func_expr = (BoundFunctionExpression &)state.expr;
auto &info = (EncryptFunctionData &)*func_expr.bind_info;
// get Database config
auto &config = DBConfig::GetConfig(*info.context.db);
return config.encryption_util;
}

static void EncryptData(DataChunk &args, ExpressionState &state,
Vector &result) {

// auto &func_expr = (BoundFunctionExpression &)state.expr;
// // bind_info ptr is null
// auto &info = (EncryptFunctionData &)*func_expr.bind_info;
// // get Database config
// auto &config = DBConfig::GetConfig(*info.context.db);
// auto encryption_util = config.encryption_util;

auto &name_vector = args.data[0];
// auto encryption_state = InitializeEncryption(state);

// actually, fix to get encryption_state from db config
auto encryption_state = InitializeEncryption();
// we do this here, we actually need to keep track of a pointer all the time
uint8_t encryption_buffer[MAX_BUFFER_SIZE];
uint8_t *buffer = encryption_buffer;

// TODO; handle all different input types
UnaryExecutor::Execute<string_t, string_t>(
name_vector, result, args.size(), [&](string_t name) {

// we do this here, we actually need to keep track of a pointer all the time
uint8_t encryption_buffer[MAX_BUFFER_SIZE];
uint8_t *buffer = encryption_buffer;

// For now; new encryption state for every new value
// does this have to do with multithreading or something?
auto encryption_state = InitializeEncryption(state);
auto size = name.GetSize();
//std::fill(encryption_buffer, encryption_buffer + size, 0);
// round the size to multiple of 16 for encryption efficiency
Expand All @@ -159,11 +149,11 @@ static void EncryptData(DataChunk &args, ExpressionState &state,
string_t encrypted_data(reinterpret_cast<const char *>(buffer), size);
auto printable_encrypted_data = Blob::ToString(encrypted_data);

D_ASSERT(CheckEncryption(printable_encrypted_data, buffer, size, reinterpret_cast<const_data_ptr_t>(name.GetData())) == 1);
D_ASSERT(CheckEncryption(printable_encrypted_data, buffer, size, reinterpret_cast<const_data_ptr_t>(name.GetData()), state) == 1);

// attach the tag at the end of the encrypted data
unsigned char tag[16];
encryption_state->Finalize(buffer, 0, tag, 16);
// unsigned char tag[16];
// encryption_state->Finalize(buffer, 0, tag, 16);

// buffer pointer stays at the start haha
// buffer -= size;
Expand All @@ -177,7 +167,8 @@ ScalarFunctionSet GetEncryptionFunction() {

// TODO; support all available types for encryption
for (auto &type : LogicalType::AllTypes()) {
set.AddFunction(ScalarFunction({type}, LogicalType::VARCHAR, EncryptData));
set.AddFunction(ScalarFunction({type}, LogicalType::VARCHAR, EncryptData,
EncryptFunctionData::EncryptBind));
}

return set;
Expand Down
Loading

0 comments on commit 46644e3

Please sign in to comment.