From 64182d84ccb5e137aa20157dc746df409bd5e812 Mon Sep 17 00:00:00 2001 From: ccfelius Date: Thu, 7 Nov 2024 15:01:21 +0100 Subject: [PATCH] fixed string in vector allocation --- src/core/functions/scalar/encrypt.cpp | 49 +++++++++++++-------------- test/sql/simple_encryption.test | 4 +-- 2 files changed, 25 insertions(+), 28 deletions(-) diff --git a/src/core/functions/scalar/encrypt.cpp b/src/core/functions/scalar/encrypt.cpp index deeb14b..0820a74 100644 --- a/src/core/functions/scalar/encrypt.cpp +++ b/src/core/functions/scalar/encrypt.cpp @@ -132,37 +132,34 @@ bool CheckEncryption(string_t printable_encrypted_data, uint8_t *buffer, template typename std::enable_if::value || std::is_floating_point::value, T>::type -ConvertCipherText(uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) { +ConvertCipherText(Vector &vector, uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) { T encrypted_data; memcpy(&encrypted_data, buffer_p, sizeof(T)); return encrypted_data; } // Handle string_t type and convert to Base64 +template +typename std::enable_if::value, T>::type +ConvertCipherText(Vector &vector, uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) { + string_t input(reinterpret_cast(buffer_p), data_size); + size_t base64_size = Blob::ToBase64Size(input); + string_t output = StringVector::EmptyString(vector, base64_size); + + // Convert blob to base64 + Blob::ToBase64(input, output.GetDataWriteable()); + + return output; +} + +// TODO: for decryption, convert a string to blob and then decrypt and then return string_t? //template //typename std::enable_if::value, T>::type //ConvertCipherText(uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) { -// // Create a blob from the encrypted buffer data -// string_t blob(reinterpret_cast(buffer_p), data_size); -// -// // Define a base64 output buffer large enough to store the encoded result -// size_t base64_size = Blob::ToBase64Size(blob); -// unique_ptr base64_output(new char[base64_size]); -// -// // Convert blob to base64 and store it in the output buffer -// Blob::ToBase64(blob, base64_output.get()); -// -// // Return the base64-encoded result as a new string_t -// return string_t(blob.GetString()); +// // +// return string_t(reinterpret_cast(buffer_p), data_size); //} -// TODO: for decryption, convert a string to blob and then decrypt and then return string_t? -template -typename std::enable_if::value, T>::type -ConvertCipherText(uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) { - return string_t(reinterpret_cast(buffer_p), data_size); -} - // Catch-all for unsupported types template typename std::enable_if::value && !std::is_floating_point::value && !std::is_same::value, T>::type @@ -211,7 +208,7 @@ void ExecuteEncryptExecutor(Vector &vector, Vector &result, idx_t size, Expressi // Encrypt data encryption_state->Process(byte_array, data_size, buffer_p, data_size); - T encrypted_data = ConvertCipherText(buffer_p, data_size, byte_array); + T encrypted_data = ConvertCipherText(result, buffer_p, data_size, byte_array); #if 0 D_ASSERT(CheckEncryption(printable_encrypted_data, buffer_p, size, reinterpret_cast(name.GetData()), state) == 1); @@ -269,7 +266,7 @@ void ExecuteDecryptExecutor(Vector &vector, Vector &result, idx_t size, Expressi // Encrypt data encryption_state->Process(byte_array, data_size, buffer_p, data_size); - T decrypted_data = ConvertCipherText(buffer_p, data_size, byte_array); + T decrypted_data = ConvertCipherText(result, buffer_p, data_size, byte_array); #if 0 D_ASSERT(CheckEncryption(printable_encrypted_data, buffer_p, size, reinterpret_cast(name.GetData()), state) == 1); @@ -343,12 +340,12 @@ ScalarFunctionSet GetEncryptionFunction() { set.AddFunction(ScalarFunction({LogicalTypeId::BIGINT, LogicalType::VARCHAR}, LogicalTypeId::BIGINT, EncryptData, EncryptFunctionData::EncryptBind)); - set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BLOB, EncryptData, - EncryptFunctionData::EncryptBind)); - -// set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, EncryptData, +// set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BLOB, EncryptData, // EncryptFunctionData::EncryptBind)); + set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, EncryptData, + EncryptFunctionData::EncryptBind)); + return set; } diff --git a/test/sql/simple_encryption.test b/test/sql/simple_encryption.test index 50138dc..2220da2 100644 --- a/test/sql/simple_encryption.test +++ b/test/sql/simple_encryption.test @@ -39,11 +39,11 @@ SELECT decrypt(4095259532786215143, '0123456789112345'); query I SELECT encrypt('testtest', '0123456789112345'); ---- -\xF6N\xCEt\xE4]\xA6L +9k7OdORdpkw= #VARCHAR query I -SELECT decrypt('\xF6N\xCEt\xE4]\xA6L', '0123456789112345'); +SELECT decrypt('9k7OdORdpkw=', '0123456789112345'); ---- testtest