Skip to content

Commit

Permalink
fixed string in vector allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
ccfelius committed Nov 7, 2024
1 parent c22135a commit 64182d8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 28 deletions.
49 changes: 23 additions & 26 deletions src/core/functions/scalar/encrypt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,37 +132,34 @@ bool CheckEncryption(string_t printable_encrypted_data, uint8_t *buffer,

template <typename T>
typename std::enable_if<std::is_integral<T>::value || std::is_floating_point<T>::value, T>::type
ConvertCipherText(uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) {
ConvertCipherText(Vector &vector, uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) {
T encrypted_data;
memcpy(&encrypted_data, buffer_p, sizeof(T));
return encrypted_data;
}

// Handle string_t type and convert to Base64
template <typename T>
typename std::enable_if<std::is_same<T, string_t>::value, T>::type
ConvertCipherText(Vector &vector, uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) {
string_t input(reinterpret_cast<const char *>(buffer_p), data_size);
size_t base64_size = Blob::ToBase64Size(input);
string_t output = StringVector::EmptyString(vector, base64_size);

// Convert blob to base64
Blob::ToBase64(input, output.GetDataWriteable());

return output;
}

// TODO: for decryption, convert a string to blob and then decrypt and then return string_t?
//template <typename T>
//typename std::enable_if<std::is_same<T, string_t>::value, T>::type
//ConvertCipherText(uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) {
// // Create a blob from the encrypted buffer data
// string_t blob(reinterpret_cast<const char *>(buffer_p), data_size);
//
// // Define a base64 output buffer large enough to store the encoded result
// size_t base64_size = Blob::ToBase64Size(blob);
// unique_ptr<char[]> base64_output(new char[base64_size]);
//
// // Convert blob to base64 and store it in the output buffer
// Blob::ToBase64(blob, base64_output.get());
//
// // Return the base64-encoded result as a new string_t
// return string_t(blob.GetString());
// //
// return string_t(reinterpret_cast<const char *>(buffer_p), data_size);
//}

// TODO: for decryption, convert a string to blob and then decrypt and then return string_t?
template <typename T>
typename std::enable_if<std::is_same<T, string_t>::value, T>::type
ConvertCipherText(uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) {
return string_t(reinterpret_cast<const char *>(buffer_p), data_size);
}

// Catch-all for unsupported types
template <typename T>
typename std::enable_if<!std::is_integral<T>::value && !std::is_floating_point<T>::value && !std::is_same<T, string_t>::value, T>::type
Expand Down Expand Up @@ -211,7 +208,7 @@ void ExecuteEncryptExecutor(Vector &vector, Vector &result, idx_t size, Expressi

// Encrypt data
encryption_state->Process(byte_array, data_size, buffer_p, data_size);
T encrypted_data = ConvertCipherText<T>(buffer_p, data_size, byte_array);
T encrypted_data = ConvertCipherText<T>(result, buffer_p, data_size, byte_array);

#if 0
D_ASSERT(CheckEncryption(printable_encrypted_data, buffer_p, size, reinterpret_cast<const_data_ptr_t>(name.GetData()), state) == 1);
Expand Down Expand Up @@ -269,7 +266,7 @@ void ExecuteDecryptExecutor(Vector &vector, Vector &result, idx_t size, Expressi

// Encrypt data
encryption_state->Process(byte_array, data_size, buffer_p, data_size);
T decrypted_data = ConvertCipherText<T>(buffer_p, data_size, byte_array);
T decrypted_data = ConvertCipherText<T>(result, buffer_p, data_size, byte_array);

#if 0
D_ASSERT(CheckEncryption(printable_encrypted_data, buffer_p, size, reinterpret_cast<const_data_ptr_t>(name.GetData()), state) == 1);
Expand Down Expand Up @@ -343,12 +340,12 @@ ScalarFunctionSet GetEncryptionFunction() {
set.AddFunction(ScalarFunction({LogicalTypeId::BIGINT, LogicalType::VARCHAR}, LogicalTypeId::BIGINT, EncryptData,
EncryptFunctionData::EncryptBind));

set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BLOB, EncryptData,
EncryptFunctionData::EncryptBind));

// set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, EncryptData,
// set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BLOB, EncryptData,
// EncryptFunctionData::EncryptBind));

set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, EncryptData,
EncryptFunctionData::EncryptBind));

return set;
}

Expand Down
4 changes: 2 additions & 2 deletions test/sql/simple_encryption.test
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ SELECT decrypt(4095259532786215143, '0123456789112345');
query I
SELECT encrypt('testtest', '0123456789112345');
----
\xF6N\xCEt\xE4]\xA6L
9k7OdORdpkw=

#VARCHAR
query I
SELECT decrypt('\xF6N\xCEt\xE4]\xA6L', '0123456789112345');
SELECT decrypt('9k7OdORdpkw=', '0123456789112345');
----
testtest

Expand Down

0 comments on commit 64182d8

Please sign in to comment.