diff --git a/CMakeLists.txt b/CMakeLists.txt index c266743..2035db9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,7 @@ set(LOADABLE_EXTENSION_NAME ${TARGET_NAME}_loadable_extension) project(${TARGET_NAME}) include_directories(src/include) add_subdirectory(src) +include_directories(../duckdb/third_party/httplib/include) # by now do this manually set(EXTENSION_SOURCES src/simple_encryption_extension.cpp @@ -26,6 +27,7 @@ set(EXTENSION_SOURCES src/simple_encryption_extension.cpp src/core/functions/function_data/encrypt_function_data.cpp src/core/functions/table/encrypt_table.cpp src/core/utils/simple_encryption_utils.cpp + src/core/crypto/crypto_primitives.cpp ) build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES}) diff --git a/extension_config.cmake b/extension_config.cmake index 688eec3..400779c 100644 --- a/extension_config.cmake +++ b/extension_config.cmake @@ -1,13 +1,9 @@ # This file is included by DuckDB's build system. It specifies which extension to load - # Extension from this repo duckdb_extension_load(simple_encryption LOAD_TEST SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR} ) -# Load httpfs extension -duckdb_extension_load(httpfs) - # Any extra extensions that should be built # e.g.: duckdb_extension_load(json) \ No newline at end of file diff --git a/src/core/crypto/CMakeLists.txt b/src/core/crypto/CMakeLists.txt new file mode 100644 index 0000000..db95d8f --- /dev/null +++ b/src/core/crypto/CMakeLists.txt @@ -0,0 +1,5 @@ +set(EXTENSION_SOURCES + ${EXTENSION_SOURCES} + ${CMAKE_CURRENT_SOURCE_DIR}/crypto_primitives.cpp + PARENT_SCOPE +) \ No newline at end of file diff --git a/src/core/crypto/crypto_primitives.cpp b/src/core/crypto/crypto_primitives.cpp new file mode 100644 index 0000000..1c40209 --- /dev/null +++ b/src/core/crypto/crypto_primitives.cpp @@ -0,0 +1,151 @@ +#include "simple_encryption/core/crypto/crypto_primitives.hpp" +#include "mbedtls_wrapper.hpp" +#include +#include "duckdb/common/common.hpp" +#include + +// OpenSSL functions +#include +#include +#include + +namespace duckdb { + +void sha256(const char *in, size_t in_len, hash_bytes &out) { + duckdb_mbedtls::MbedTlsWrapper::ComputeSha256Hash(in, in_len, (char *)out); +} + +void hmac256(const std::string &message, const char *secret, size_t secret_len, hash_bytes &out) { + duckdb_mbedtls::MbedTlsWrapper::Hmac256(secret, secret_len, message.data(), message.size(), (char *)out); +} + +void hmac256(std::string message, hash_bytes secret, hash_bytes &out) { + hmac256(message, (char *)secret, sizeof(hash_bytes), out); +} + +void hex256(hash_bytes &in, hash_str &out) { + const char *hex = "0123456789abcdef"; + unsigned char *pin = in; + unsigned char *pout = out; + for (; pin < in + sizeof(in); pout += 2, pin++) { + pout[0] = hex[(*pin >> 4) & 0xF]; + pout[1] = hex[*pin & 0xF]; + } +} + +const EVP_CIPHER *GetCipher(const string &key) { + // For now, we only support GCM ciphers + switch (key.size()) { + case 16: + return EVP_aes_128_gcm(); + case 24: + return EVP_aes_192_gcm(); + case 32: + return EVP_aes_256_gcm(); + default: + throw InternalException("Invalid AES key length"); + } +} + +AESStateSSL::AESStateSSL() : gcm_context(EVP_CIPHER_CTX_new()) { + if (!(gcm_context)) { + throw InternalException("AES GCM failed with initializing context"); + } +} + +AESStateSSL::~AESStateSSL() { + // Clean up + EVP_CIPHER_CTX_free(gcm_context); +} + +bool AESStateSSL::IsOpenSSL() { + return ssl; +} + +void AESStateSSL::GenerateRandomData(data_ptr_t data, idx_t len) { + // generate random bytes for nonce + RAND_bytes(data, len); +} + +void AESStateSSL::InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, const string *key) { + mode = ENCRYPT; + + if (1 != EVP_EncryptInit_ex(gcm_context, GetCipher(*key), NULL, const_data_ptr_cast(key->data()), iv)) { + throw InternalException("EncryptInit failed"); + } +} + +void AESStateSSL::InitializeDecryption(const_data_ptr_t iv, idx_t iv_len, const string *key) { + mode = DECRYPT; + + if (1 != EVP_DecryptInit_ex(gcm_context, GetCipher(*key), NULL, const_data_ptr_cast(key->data()), iv)) { + throw InternalException("DecryptInit failed"); + } +} + +size_t AESStateSSL::Process(const_data_ptr_t in, idx_t in_len, data_ptr_t out, idx_t out_len) { + + switch (mode) { + case ENCRYPT: + if (1 != EVP_EncryptUpdate(gcm_context, data_ptr_cast(out), reinterpret_cast(&out_len), + const_data_ptr_cast(in), (int)in_len)) { + throw InternalException("Encryption failed at OpenSSL EVP_EncryptUpdate"); + } + break; + + case DECRYPT: + if (1 != EVP_DecryptUpdate(gcm_context, data_ptr_cast(out), reinterpret_cast(&out_len), + const_data_ptr_cast(in), (int)in_len)) { + + throw InternalException("Decryption failed at OpenSSL EVP_DecryptUpdate"); + } + break; + } + + if (out_len != in_len) { + throw InternalException("AES GCM failed, in- and output lengths differ"); + } + + return out_len; +} + +size_t AESStateSSL::Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_t tag_len) { + auto text_len = out_len; + + switch (mode) { + case ENCRYPT: + if (1 != EVP_EncryptFinal_ex(gcm_context, data_ptr_cast(out) + out_len, reinterpret_cast(&out_len))) { + throw InternalException("EncryptFinal failed"); + } + text_len += out_len; + // The computed tag is written at the end of a chunk + if (1 != EVP_CIPHER_CTX_ctrl(gcm_context, EVP_CTRL_GCM_GET_TAG, tag_len, tag)) { + throw InternalException("Calculating the tag failed"); + } + return text_len; + case DECRYPT: + // Set expected tag value + if (!EVP_CIPHER_CTX_ctrl(gcm_context, EVP_CTRL_GCM_SET_TAG, tag_len, tag)) { + throw InternalException("Finalizing tag failed"); + } + // EVP_DecryptFinal() will return an error code if final block is not correctly formatted. + int ret = EVP_DecryptFinal_ex(gcm_context, data_ptr_cast(out) + out_len, reinterpret_cast(&out_len)); + text_len += out_len; + + if (ret > 0) { + // success + return text_len; + } + throw InvalidInputException("Computed AES tag differs from read AES tag, are you using the right key?"); + } +} + +} // namespace duckdb + +extern "C" { + +// Call the member function through the factory object +DUCKDB_EXTENSION_API AESStateSSLFactory *CreateSSLFactory() { + return new AESStateSSLFactory(); +}; +} \ No newline at end of file diff --git a/src/core/functions/function_data/encrypt_function_data.cpp b/src/core/functions/function_data/encrypt_function_data.cpp index 6e17ac4..b37a93b 100644 --- a/src/core/functions/function_data/encrypt_function_data.cpp +++ b/src/core/functions/function_data/encrypt_function_data.cpp @@ -14,7 +14,7 @@ unique_ptr EncryptFunctionData::Copy() const { bool EncryptFunctionData::Equals(const FunctionData &other_p) const { auto &other = (const EncryptFunctionData &)other_p; // fix this to return the right id - return false; + return true; } unique_ptr EncryptFunctionData::EncryptBind(ClientContext &context, ScalarFunction &bound_function, diff --git a/src/core/functions/scalar/encrypt.cpp b/src/core/functions/scalar/encrypt.cpp index 995ff02..ed5d785 100644 --- a/src/core/functions/scalar/encrypt.cpp +++ b/src/core/functions/scalar/encrypt.cpp @@ -24,65 +24,69 @@ namespace simple_encryption { namespace core { -shared_ptr InitializeCryptoState() { +shared_ptr GetEncryptionUtil(ExpressionState &state) { + auto &func_expr = (BoundFunctionExpression &)state.expr; + auto &info = (EncryptFunctionData &)*func_expr.bind_info; + // get Database config + auto &config = DBConfig::GetConfig(*info.context.db); + return config.encryption_util; +} -// auto &info = (CSRFunctionData &)*func_expr.bind_info; -// auto simple_encryption_state = info.context.registered_state->Get("simple_encryption"); -// -// if (!simple_encryption_state) { -// throw MissingExtensionException( -// "The simple_encryption extension has not been loaded"); -// } - - // for now just harcode MBEDTLS here - shared_ptr encryption_state = - duckdb_mbedtls::MbedTlsWrapper::AESGCMStateMBEDTLSFactory() - .CreateEncryptionState(); - return encryption_state; +shared_ptr InitializeCryptoState(ExpressionState &state) { + auto encryption_state = GetEncryptionUtil(state); + + if (!encryption_state) { + return duckdb_mbedtls::MbedTlsWrapper::AESGCMStateMBEDTLSFactory() + .CreateEncryptionState(); + } + + return encryption_state->CreateEncryptionState(); } -shared_ptr InitializeEncryption() { +shared_ptr InitializeEncryption(ExpressionState &state) { // For now, hardcode everything + // for some reason, this is 12 const string key = TEST_KEY; - unsigned char iv[16]; + unsigned char iv[12]; memcpy((void *)iv, "12345678901", 12); +// +// // TODO; construct nonce based on immutable ROW_ID + hash(col_name) +// iv[12] = 0x00; +// iv[13] = 0x00; +// iv[14] = 0x00; +// iv[15] = 0x00; - // TODO; construct nonce based on immutable ROW_ID + hash(col_name) - iv[12] = 0x00; - iv[13] = 0x00; - iv[14] = 0x00; - iv[15] = 0x00; - - auto encryption_state = InitializeCryptoState(); - encryption_state->InitializeEncryption(iv, 16, &key); + auto encryption_state = InitializeCryptoState(state); +// encryption_state->GenerateRandomData(iv, 12); + encryption_state->InitializeEncryption(iv, 12, &key); return encryption_state; } -shared_ptr InitializeDecryption() { +shared_ptr InitializeDecryption(ExpressionState &state) { // For now, hardcode everything const string key = TEST_KEY; - unsigned char iv[16]; + unsigned char iv[12]; memcpy((void *)iv, "12345678901", 12); // // // TODO; construct nonce based on immutable ROW_ID + hash(col_name) - iv[12] = 0x00; - iv[13] = 0x00; - iv[14] = 0x00; - iv[15] = 0x00; +// iv[12] = 0x00; +// iv[13] = 0x00; +// iv[14] = 0x00; +// iv[15] = 0x00; - auto decryption_state = InitializeCryptoState(); + auto decryption_state = InitializeCryptoState(state); decryption_state->InitializeDecryption(iv, 16, &key); return decryption_state; } -inline const uint8_t *DecryptValue(uint8_t *buffer, size_t size) { +inline const uint8_t *DecryptValue(uint8_t *buffer, size_t size, ExpressionState &state) { // Initialize Encryption - auto encryption_state = InitializeDecryption(); + auto encryption_state = InitializeDecryption(state); uint8_t decryption_buffer[MAX_BUFFER_SIZE]; uint8_t *temp_buf = decryption_buffer; @@ -92,7 +96,7 @@ inline const uint8_t *DecryptValue(uint8_t *buffer, size_t size) { } bool CheckEncryption(string_t printable_encrypted_data, uint8_t *buffer, - size_t size, const uint8_t *value){ + size_t size, const uint8_t *value, ExpressionState &state){ // cast encrypted data to blob back and forth // to check whether data will be lost with casting @@ -105,7 +109,7 @@ bool CheckEncryption(string_t printable_encrypted_data, uint8_t *buffer, "Original Encrypted Data differs from Unblobbed Encrypted Data"); } - auto decrypted_data = DecryptValue(buffer, size); + auto decrypted_data = DecryptValue(buffer, size, state); if (memcmp(decrypted_data, value, size) != 0) { throw InvalidInputException( "Original Data differs from Decrypted Data"); @@ -114,37 +118,23 @@ bool CheckEncryption(string_t printable_encrypted_data, uint8_t *buffer, return true; } -shared_ptr GetEncryptionUtil(ExpressionState &state) { - auto &func_expr = (BoundFunctionExpression &)state.expr; - auto &info = (EncryptFunctionData &)*func_expr.bind_info; - // get Database config - auto &config = DBConfig::GetConfig(*info.context.db); - return config.encryption_util; -} - static void EncryptData(DataChunk &args, ExpressionState &state, Vector &result) { -// auto &func_expr = (BoundFunctionExpression &)state.expr; -// // bind_info ptr is null -// auto &info = (EncryptFunctionData &)*func_expr.bind_info; -// // get Database config -// auto &config = DBConfig::GetConfig(*info.context.db); -// auto encryption_util = config.encryption_util; - auto &name_vector = args.data[0]; +// auto encryption_state = InitializeEncryption(state); - // actually, fix to get encryption_state from db config - auto encryption_state = InitializeEncryption(); + // we do this here, we actually need to keep track of a pointer all the time + uint8_t encryption_buffer[MAX_BUFFER_SIZE]; + uint8_t *buffer = encryption_buffer; // TODO; handle all different input types UnaryExecutor::Execute( name_vector, result, args.size(), [&](string_t name) { - // we do this here, we actually need to keep track of a pointer all the time - uint8_t encryption_buffer[MAX_BUFFER_SIZE]; - uint8_t *buffer = encryption_buffer; - + // For now; new encryption state for every new value + // does this has to do with multithreading or something? + auto encryption_state = InitializeEncryption(state); auto size = name.GetSize(); //std::fill(encryption_buffer, encryption_buffer + size, 0); // round the size to multiple of 16 for encryption efficiency @@ -159,11 +149,11 @@ static void EncryptData(DataChunk &args, ExpressionState &state, string_t encrypted_data(reinterpret_cast(buffer), size); auto printable_encrypted_data = Blob::ToString(encrypted_data); - D_ASSERT(CheckEncryption(printable_encrypted_data, buffer, size, reinterpret_cast(name.GetData())) == 1); + D_ASSERT(CheckEncryption(printable_encrypted_data, buffer, size, reinterpret_cast(name.GetData()), state) == 1); // attach the tag at the end of the encrypted data - unsigned char tag[16]; - encryption_state->Finalize(buffer, 0, tag, 16); +// unsigned char tag[16]; +// encryption_state->Finalize(buffer, 0, tag, 16); // buffer pointer stays at the start haha // buffer -= size; @@ -177,7 +167,8 @@ ScalarFunctionSet GetEncryptionFunction() { // TODO; support all available types for encryption for (auto &type : LogicalType::AllTypes()) { - set.AddFunction(ScalarFunction({type}, LogicalType::VARCHAR, EncryptData)); + set.AddFunction(ScalarFunction({type}, LogicalType::VARCHAR, EncryptData, + EncryptFunctionData::EncryptBind)); } return set; diff --git a/src/include/simple_encryption/core/crypto/crypto_primitives.hpp b/src/include/simple_encryption/core/crypto/crypto_primitives.hpp new file mode 100644 index 0000000..324d562 --- /dev/null +++ b/src/include/simple_encryption/core/crypto/crypto_primitives.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include "duckdb/common/encryption_state.hpp" +#include "duckdb/common/helper.hpp" + +#include +#include + + +typedef struct evp_cipher_ctx_st EVP_CIPHER_CTX; + +namespace duckdb { + +typedef unsigned char hash_bytes[32]; +typedef unsigned char hash_str[64]; + +void sha256(const char *in, size_t in_len, hash_bytes &out); + +void hmac256(const std::string &message, const char *secret, size_t secret_len, hash_bytes &out); + +void hmac256(std::string message, hash_bytes secret, hash_bytes &out); + +void hex256(hash_bytes &in, hash_str &out); + +class DUCKDB_EXTENSION_API AESStateSSL : public duckdb::EncryptionState { + +public: + explicit AESStateSSL(); + ~AESStateSSL() override; + +public: + bool IsOpenSSL() override; + void InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, const std::string *key) override; + void InitializeDecryption(const_data_ptr_t iv, idx_t iv_len, const std::string *key) override; + size_t Process(const_data_ptr_t in, idx_t in_len, data_ptr_t out, idx_t out_len) override; + size_t Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_t tag_len) override; + void GenerateRandomData(data_ptr_t data, idx_t len) override; + +private: + bool ssl = true; + EVP_CIPHER_CTX *gcm_context; + Mode mode; +}; + +} // namespace duckdb + +extern "C" { + +class DUCKDB_EXTENSION_API AESStateSSLFactory : public duckdb::EncryptionUtil { +public: + explicit AESStateSSLFactory() { + } + + duckdb::shared_ptr CreateEncryptionState() const override { + return duckdb::make_shared_ptr(); + } + + ~AESStateSSLFactory() override { + } +}; +} \ No newline at end of file diff --git a/src/include/simple_encryption/core/module.hpp b/src/include/simple_encryption/core/module.hpp index 49347fa..18c8f0f 100644 --- a/src/include/simple_encryption/core/module.hpp +++ b/src/include/simple_encryption/core/module.hpp @@ -9,7 +9,6 @@ struct CoreModule { public: static void Register(DatabaseInstance &db); - }; } // namespace core diff --git a/src/simple_encryption_extension.cpp b/src/simple_encryption_extension.cpp index 18c1556..636bd8b 100644 --- a/src/simple_encryption_extension.cpp +++ b/src/simple_encryption_extension.cpp @@ -17,6 +17,7 @@ #include "duckdb/main/connection_manager.hpp" #include #include "simple_encryption/core/module.hpp" +#include "simple_encryption/core/crypto/crypto_primitives.hpp" namespace duckdb { @@ -27,6 +28,9 @@ static void LoadInternal(DatabaseInstance &instance) { // Register the SimpleEncryptionState for all connections auto &config = DBConfig::GetConfig(instance); + + // set pointer to OpenSSL encryption state + config.encryption_util = make_shared_ptr(); config.extension_callbacks.push_back(make_uniq()); // Register the SimpleEncryptionState for all connections