diff --git a/src/core/crypto/crypto_primitives.cpp b/src/core/crypto/crypto_primitives.cpp index 534f576..e26c6ea 100644 --- a/src/core/crypto/crypto_primitives.cpp +++ b/src/core/crypto/crypto_primitives.cpp @@ -33,7 +33,6 @@ void hex256(hash_bytes &in, hash_str &out) { } } -// it's nowhere defined so this is fine const EVP_CIPHER *GetCipher(const string &key, AESStateSSL::Algorithm algorithm) { switch(algorithm) { @@ -75,15 +74,15 @@ const EVP_CIPHER *GetCipher(const string &key, AESStateSSL::Algorithm algorithm) } } -AESStateSSL::AESStateSSL() : gcm_context(EVP_CIPHER_CTX_new()) { - if (!(gcm_context)) { +AESStateSSL::AESStateSSL() : context(EVP_CIPHER_CTX_new()) { + if (!(context)) { throw InternalException("AES GCM failed with initializing context"); } } AESStateSSL::~AESStateSSL() { // Clean up - EVP_CIPHER_CTX_free(gcm_context); + EVP_CIPHER_CTX_free(context); } bool AESStateSSL::IsOpenSSL() { @@ -113,7 +112,7 @@ void AESStateSSL::InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, const mode = ENCRYPT; - if (1 != EVP_EncryptInit_ex(gcm_context, GetCipher(*key, algorithm), NULL, const_data_ptr_cast(key->data()), iv)) { + if (1 != EVP_EncryptInit_ex(context, GetCipher(*key, algorithm), NULL, const_data_ptr_cast(key->data()), iv)) { throw InternalException("EncryptInit failed"); } } @@ -121,7 +120,7 @@ void AESStateSSL::InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, const void AESStateSSL::InitializeDecryption(const_data_ptr_t iv, idx_t iv_len, const string *key) { mode = DECRYPT; - if (1 != EVP_DecryptInit_ex(gcm_context, GetCipher(*key, algorithm), NULL, const_data_ptr_cast(key->data()), iv)) { + if (1 != EVP_DecryptInit_ex(context, GetCipher(*key, algorithm), NULL, const_data_ptr_cast(key->data()), iv)) { throw InternalException("DecryptInit failed"); } } @@ -130,14 +129,14 @@ size_t AESStateSSL::Process(const_data_ptr_t in, idx_t in_len, data_ptr_t out, i switch (mode) { case ENCRYPT: - if (1 != EVP_EncryptUpdate(gcm_context, data_ptr_cast(out), reinterpret_cast(&out_len), + if (1 != EVP_EncryptUpdate(context, data_ptr_cast(out), reinterpret_cast(&out_len), const_data_ptr_cast(in), (int)in_len)) { throw InternalException("Encryption failed at OpenSSL EVP_EncryptUpdate"); } break; case DECRYPT: - if (1 != EVP_DecryptUpdate(gcm_context, data_ptr_cast(out), reinterpret_cast(&out_len), + if (1 != EVP_DecryptUpdate(context, data_ptr_cast(out), reinterpret_cast(&out_len), const_data_ptr_cast(in), (int)in_len)) { throw InternalException("Decryption failed at OpenSSL EVP_DecryptUpdate"); @@ -156,23 +155,35 @@ size_t AESStateSSL::Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_ auto text_len = out_len; switch (mode) { + case ENCRYPT: - if (1 != EVP_EncryptFinal_ex(gcm_context, data_ptr_cast(out) + out_len, reinterpret_cast(&out_len))) { + if (1 != EVP_EncryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast(&out_len))) { throw InternalException("EncryptFinal failed"); } - text_len += out_len; - // The computed tag is written at the end of a chunk - if (1 != EVP_CIPHER_CTX_ctrl(gcm_context, EVP_CTRL_GCM_GET_TAG, tag_len, tag)) { + + if (algorithm == CTR) { + return text_len; + } + + // The computed tag is written at the end of a chunk for OCB and GCM + if (1 != EVP_CIPHER_CTX_ctrl(context, EVP_CTRL_GCM_GET_TAG, tag_len, + tag)) { throw InternalException("Calculating the tag failed"); } return text_len; + case DECRYPT: - // Set expected tag value - if (!EVP_CIPHER_CTX_ctrl(gcm_context, EVP_CTRL_GCM_SET_TAG, tag_len, tag)) { - throw InternalException("Finalizing tag failed"); + + if (algorithm != CTR){ + // Set expected tag value + if (!EVP_CIPHER_CTX_ctrl(context, EVP_CTRL_GCM_SET_TAG, tag_len, + tag)) { + throw InternalException("Finalizing tag failed"); + } } + // EVP_DecryptFinal() will return an error code if final block is not correctly formatted. - int ret = EVP_DecryptFinal_ex(gcm_context, data_ptr_cast(out) + out_len, reinterpret_cast(&out_len)); + int ret = EVP_DecryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast(&out_len)); text_len += out_len; if (ret > 0) { diff --git a/src/core/functions/scalar/encrypt.cpp b/src/core/functions/scalar/encrypt.cpp index ed5d785..56092f4 100644 --- a/src/core/functions/scalar/encrypt.cpp +++ b/src/core/functions/scalar/encrypt.cpp @@ -2,6 +2,7 @@ #define TEST_KEY "0123456789112345" #define MAX_BUFFER_SIZE 1024 +#define MAX_BUFFER_SIZE_2 8096 #include "duckdb.hpp" #include "duckdb/common/exception.hpp" @@ -48,8 +49,8 @@ shared_ptr InitializeEncryption(ExpressionState &state) { // For now, hardcode everything // for some reason, this is 12 const string key = TEST_KEY; - unsigned char iv[12]; - memcpy((void *)iv, "12345678901", 12); + unsigned char iv[16]; +// memcpy((void *)iv, "12345678901", 16); // // // TODO; construct nonce based on immutable ROW_ID + hash(col_name) // iv[12] = 0x00; @@ -58,8 +59,8 @@ shared_ptr InitializeEncryption(ExpressionState &state) { // iv[15] = 0x00; auto encryption_state = InitializeCryptoState(state); -// encryption_state->GenerateRandomData(iv, 12); - encryption_state->InitializeEncryption(iv, 12, &key); +// encryption_state->GenerateRandomData(iv, 16); + encryption_state->InitializeEncryption(iv, 16, &key); return encryption_state; } @@ -68,8 +69,8 @@ shared_ptr InitializeDecryption(ExpressionState &state) { // For now, hardcode everything const string key = TEST_KEY; - unsigned char iv[12]; - memcpy((void *)iv, "12345678901", 12); + unsigned char iv[16]; + memcpy((void *)iv, "12345678901", 16); // // // TODO; construct nonce based on immutable ROW_ID + hash(col_name) // iv[12] = 0x00; @@ -122,24 +123,32 @@ static void EncryptData(DataChunk &args, ExpressionState &state, Vector &result) { auto &name_vector = args.data[0]; -// auto encryption_state = InitializeEncryption(state); + auto encryption_state = InitializeEncryption(state); // we do this here, we actually need to keep track of a pointer all the time - uint8_t encryption_buffer[MAX_BUFFER_SIZE]; - uint8_t *buffer = encryption_buffer; +// uint8_t encryption_buffer[MAX_BUFFER_SIZE]; +// uint8_t *buffer = encryption_buffer; // TODO; handle all different input types UnaryExecutor::Execute( name_vector, result, args.size(), [&](string_t name) { + // renew for each value + uint8_t encryption_buffer[MAX_BUFFER_SIZE]; + uint8_t *buffer = encryption_buffer; // For now; new encryption state for every new value // does this has to do with multithreading or something? - auto encryption_state = InitializeEncryption(state); auto size = name.GetSize(); //std::fill(encryption_buffer, encryption_buffer + size, 0); // round the size to multiple of 16 for encryption efficiency // size = (size + 15) & ~15; + unsigned char iv[16]; + const string key = TEST_KEY; + + encryption_state->GenerateRandomData(iv, 16); + encryption_state->InitializeEncryption(iv, 16, &key); + encryption_state->Process( reinterpret_cast(name.GetData()), size, buffer, size); @@ -147,6 +156,8 @@ static void EncryptData(DataChunk &args, ExpressionState &state, sizeof(encryption_buffer) / sizeof(encryption_buffer[0])); string_t encrypted_data(reinterpret_cast(buffer), size); + + auto printable_encrypted_data = Blob::ToString(encrypted_data); D_ASSERT(CheckEncryption(printable_encrypted_data, buffer, size, reinterpret_cast(name.GetData()), state) == 1); @@ -155,9 +166,6 @@ static void EncryptData(DataChunk &args, ExpressionState &state, // unsigned char tag[16]; // encryption_state->Finalize(buffer, 0, tag, 16); - // buffer pointer stays at the start haha -// buffer -= size; - return printable_encrypted_data; }); } diff --git a/src/include/simple_encryption/core/crypto/crypto_primitives.hpp b/src/include/simple_encryption/core/crypto/crypto_primitives.hpp index c2de1bf..c8c2f08 100644 --- a/src/include/simple_encryption/core/crypto/crypto_primitives.hpp +++ b/src/include/simple_encryption/core/crypto/crypto_primitives.hpp @@ -44,11 +44,11 @@ class DUCKDB_EXTENSION_API AESStateSSL : public duckdb::EncryptionState { private: bool ssl = true; - EVP_CIPHER_CTX *gcm_context; + EVP_CIPHER_CTX *context; Mode mode; // default value is GCM - Algorithm algorithm = GCM; + Algorithm algorithm = CTR; }; } // namespace duckdb