Skip to content

Commit

Permalink
add decryption tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ccfelius committed Nov 7, 2024
1 parent 523e1b1 commit 188ede7
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 147 deletions.
237 changes: 96 additions & 141 deletions src/core/functions/scalar/encrypt.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#define DUCKDB_EXTENSION_MAIN

#define TEST_KEY "0123456789112345"

// what is the maximum size of biggest type in duckdb
#define MAX_BUFFER_SIZE 1024
#define MAX_BUFFER_SIZE_2 8096

#include "duckdb.hpp"
#include "duckdb/common/exception.hpp"
Expand Down Expand Up @@ -73,27 +74,6 @@ shared_ptr<EncryptionState> InitializeCryptoState(ExpressionState &state) {
return encryption_state->CreateEncryptionState();
}

shared_ptr<EncryptionState> InitializeEncryption(ExpressionState &state) {

// For now, hardcode everything
// for some reason, this is 12
const string key = TEST_KEY;
unsigned char iv[16];
// memcpy((void *)iv, "12345678901", 16);
//
// // TODO; construct nonce based on immutable ROW_ID + hash(col_name)
// iv[12] = 0x00;
// iv[13] = 0x00;
// iv[14] = 0x00;
// iv[15] = 0x00;

auto encryption_state = InitializeCryptoState(state);
// encryption_state->GenerateRandomData(iv, 16);
// encryption_state->InitializeEncryption(iv, 16, &key);

return encryption_state;
}

shared_ptr<EncryptionState> InitializeDecryption(ExpressionState &state) {

// For now, hardcode everything
Expand Down Expand Up @@ -148,88 +128,30 @@ bool CheckEncryption(string_t printable_encrypted_data, uint8_t *buffer,
return true;
}

static void DecryptData(DataChunk &args, ExpressionState &state,
Vector &result) {

auto &name_vector = args.data[0];
// auto encryption_state = InitializeEncryption(state);

auto const size = sizeof(string_t);

// TODO; handle all different input types
UnaryExecutor::Execute<string_t, string_t>(
name_vector, result, args.size(), [&](string_t name) {

// renew for each value
uint8_t decryption_buffer[MAX_BUFFER_SIZE];
uint8_t *buffer_p = decryption_buffer;
// For now; new encryption state for every new value
// does this has to do with multithreading or something?
// the size is suddenly 1, but we should just get the size of the input type...
auto name_size = name.GetSize();

// round the size to multiple of 16 for encryption efficiency
// size = (size + 15) & ~15;

unsigned char iv[16];
const string key = TEST_KEY;
auto encryption_state = InitializeCryptoState(state);

// fix IV for now
memcpy((void *)iv, "12345678901", 16);
//
// // TODO; construct nonce based on immutable ROW_ID + hash(col_name)
iv[12] = 0x00;
iv[13] = 0x00;
iv[14] = 0x00;
iv[15] = 0x00;

// encryption_state->GenerateRandomData(iv, 16);
encryption_state->InitializeDecryption(iv, 16, &key);

// at some point, input gets invalid
auto input = reinterpret_cast<const_data_ptr_t>(name.GetData());
encryption_state->Process(input, name_size, buffer_p, name_size);

#if 0
D_ASSERT(MAX_BUFFER_SIZE ==
sizeof(encryption_buffer) / sizeof(encryption_buffer[0]));
#endif

string_t decrypted_data(reinterpret_cast<const char *>(buffer_p), name_size);
auto printable_decrypted_data = Blob::ToString(decrypted_data);

#if 0
D_ASSERT(CheckEncryption(printable_encrypted_data, buffer_p, size, reinterpret_cast<const_data_ptr_t>(name.GetData()), state) == 1);
#endif

// attach the tag at the end of the encrypted data
unsigned char tag[16];
// this does not do anything for CTR
encryption_state->Finalize(buffer_p, 0, tag, 16);
return printable_decrypted_data;
});
}


// FIX: make C++11 compatible
// Generated code
// misschien duckdb types doen ipv die andere dingen
template <typename T>
T EncryptAndConvert(uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) {
T ConvertCipherText(uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) {

if constexpr (std::is_integral<T>::value || std::is_floating_point<T>::value) {

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_threads, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_threads, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_mvp, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_mvp, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_eh, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_eh, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_threads, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_mvp, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_eh, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 138 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64_rtools, x64-mingw-static)

'if constexpr' only available with '-std=c++17' or '-std=gnu++17'
T encrypted_data;
memcpy(&encrypted_data, buffer_p, sizeof(T));
return encrypted_data;

} else if constexpr (std::is_same<T, string_t>::value) {

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_threads, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_threads, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_mvp, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_mvp, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_eh, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_eh, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_threads, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_mvp, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_eh, wasm32-emscripten)

constexpr if is a C++17 extension [-Wc++17-extensions]

Check warning on line 143 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64_rtools, x64-mingw-static)

'if constexpr' only available with '-std=c++17' or '-std=gnu++17'
// string_t decrypted_data(reinterpret_cast<const char *>(buffer_p), name_size);
// auto printable_decrypted_data = Blob::ToString(decrypted_data);
return string_t(reinterpret_cast<const char *>(buffer_p), data_size);

} else {
InvalidInputException("Unsupported type for encryption");
InvalidInputException("Unsupported type for Encryption");
}
}

// FIX: make C++11 compatible
// Generated code
template <typename T>
size_t GetSizeOfInput(const T &input) {

Expand All @@ -249,7 +171,7 @@ size_t GetSizeOfInput(const T &input) {


template <typename T>
void ExecuteWithUnaryExecutor(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) {
void ExecuteEncryptExecutor(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) {

// TODO: put this in the state of the extension
uint8_t encryption_buffer[MAX_BUFFER_SIZE];
Expand All @@ -258,61 +180,107 @@ void ExecuteWithUnaryExecutor(Vector &vector, Vector &result, idx_t size, Expres
unsigned char iv[16];
auto encryption_state = InitializeCryptoState(state);

// TODO; construct nonce based on immutable ROW_ID + hash(col_name)
// TODO: construct nonce based on immutable ROW_ID + hash(col_name)
memcpy(iv, "12345678901", 12);
iv[12] = iv[13] = iv[14] = iv[15] = 0x00;

UnaryExecutor::Execute<T, T>(vector, result, size, [&](T input) -> T {

unsigned char byte_array[sizeof(T)];
auto data_size = GetSizeOfInput(input);
encryption_state->InitializeEncryption(iv, 16, &key_t);
// at some point, input gets invalid
// auto input_data = reinterpret_cast<const_data_ptr_t>(input);

// klopt dit wel?
unsigned char byte_array[sizeof(T)];
// Convert input to byte array for processing
memcpy(byte_array, &input, sizeof(T));

// Optionally make it `const unsigned char*`
const unsigned char* byte_ptr = byte_array;

// Encrypt data
encryption_state->Process(byte_array, data_size, buffer_p, data_size);

// T encrypted_data(reinterpret_cast<const char *>(buffer_p), data_size);
T encrypted_data = EncryptAndConvert<T>(buffer_p, data_size, byte_array);
T encrypted_data = ConvertCipherText<T>(buffer_p, data_size, byte_array);

#if 0
D_ASSERT(CheckEncryption(printable_encrypted_data, buffer_p, size, reinterpret_cast<const_data_ptr_t>(name.GetData()), state) == 1);
#endif

// attach the tag at the end of the encrypted data
unsigned char tag[16];
// this does not do anything for CTR
// this does not do anything for CTR and therefore can be skipped
encryption_state->Finalize(buffer_p, 0, tag, 16);

return encrypted_data;
});
}

// Helper function that dispatches the runtime type to the appropriate templated function
void ExecuteWithRuntimeType(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) {
void ExecuteEncrypt(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) {
// Check the vector type and call the correct templated version
switch (vector.GetType().id()) {
case LogicalTypeId::INTEGER:
ExecuteEncryptExecutor<int32_t>(vector, result, size, state, key_t);
break;
case LogicalTypeId::BIGINT:
ExecuteEncryptExecutor<int64_t>(vector, result, size, state, key_t);
break;
case LogicalTypeId::VARCHAR:
ExecuteEncryptExecutor<string_t>(vector, result, size, state, key_t);
break;
default:
throw NotImplementedException("Unsupported type for Encryption");
}
}

template <typename T>
void ExecuteDecryptExecutor(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) {

auto gettypeid = vector.GetType();
// TODO: put this in the state of the extension
uint8_t encryption_buffer[MAX_BUFFER_SIZE];
uint8_t *buffer_p = encryption_buffer;

unsigned char iv[16];
auto encryption_state = InitializeCryptoState(state);

// TODO: construct nonce based on immutable ROW_ID + hash(col_name)
memcpy(iv, "12345678901", 12);
iv[12] = iv[13] = iv[14] = iv[15] = 0x00;

UnaryExecutor::Execute<T, T>(vector, result, size, [&](T input) -> T {
unsigned char byte_array[sizeof(T)];
auto data_size = GetSizeOfInput(input);
encryption_state->InitializeDecryption(iv, 16, &key_t);

// Convert input to byte array for processing
memcpy(byte_array, &input, sizeof(T));

// Encrypt data
encryption_state->Process(byte_array, data_size, buffer_p, data_size);
T decrypted_data = ConvertCipherText<T>(buffer_p, data_size, byte_array);

#if 0
D_ASSERT(CheckEncryption(printable_encrypted_data, buffer_p, size, reinterpret_cast<const_data_ptr_t>(name.GetData()), state) == 1);
#endif

// attach the tag at the end of the encrypted data
unsigned char tag[16];
// this does not do anything for CTR and therefore can be skipped
encryption_state->Finalize(buffer_p, 0, tag, 16);

return decrypted_data;
});
}

// Helper function that dispatches the runtime type to the appropriate templated function
void ExecuteDecrypt(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) {
// Check the vector type and call the correct templated version
switch (vector.GetType().id()) {
case LogicalTypeId::INTEGER:
ExecuteWithUnaryExecutor<int32_t>(vector, result, size, state, key_t);
ExecuteDecryptExecutor<int32_t>(vector, result, size, state, key_t);
break;
case LogicalTypeId::BIGINT:
ExecuteWithUnaryExecutor<int64_t>(vector, result, size, state, key_t);
ExecuteDecryptExecutor<int64_t>(vector, result, size, state, key_t);
break;
case LogicalTypeId::VARCHAR:
ExecuteWithUnaryExecutor<string_t>(vector, result, size, state, key_t);
ExecuteDecryptExecutor<string_t>(vector, result, size, state, key_t);
break;
// Add cases for other types as needed
default:
throw NotImplementedException("Unsupported type for UnaryExecutor");
throw NotImplementedException("Unsupported type for Encryption");
}
}

Expand All @@ -328,22 +296,27 @@ static void EncryptData(DataChunk &args, ExpressionState &state, Vector &result)
const string key_t = ConstantVector::GetData<string_t>(key_vector)[0].GetString();

// can we not pass by reference?
ExecuteWithRuntimeType(value_vector, result, args.size(), state, key_t);
ExecuteEncrypt(value_vector, result, args.size(), state, key_t);
}

static void DecryptData(DataChunk &args, ExpressionState &state, Vector &result) {

auto &value_vector = args.data[0];

// Get the encryption key
auto &key_vector = args.data[1];
D_ASSERT(key_vector.GetVectorType() == VectorType::CONSTANT_VECTOR);

// Fetch the encryption key as a constant string
const string key_t = ConstantVector::GetData<string_t>(key_vector)[0].GetString();

// can we not pass by reference?
ExecuteEncrypt(value_vector, result, args.size(), state, key_t);
}

ScalarFunctionSet GetEncryptionFunction() {
ScalarFunctionSet set("encrypt");

// input is column of any type, key is of type VARCHAR, output is of same type
// set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, LogicalTypeId::INTEGER, EncryptData,
// EncryptFunctionData::EncryptBind));
//
// set.AddFunction(ScalarFunction({LogicalTypeId::BIGINT, LogicalType::VARCHAR}, LogicalTypeId::BIGINT, EncryptData,
// EncryptFunctionData::EncryptBind));
//
// set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, EncryptData,
// EncryptFunctionData::EncryptBind));

set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, LogicalTypeId::INTEGER, EncryptData,
EncryptFunctionData::EncryptBind));

Expand All @@ -353,14 +326,8 @@ ScalarFunctionSet GetEncryptionFunction() {
set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BLOB, EncryptData,
EncryptFunctionData::EncryptBind));

// TODO; support all available types for encryption
// for (auto &type : LogicalType::AllTypes()) {
//
// // input is column of any type, key is of type VARCHAR, output is of same type
// set.AddFunction(ScalarFunction({type, LogicalType::VARCHAR}, type, EncryptData,
// EncryptFunctionData::EncryptBind));
//
// }
// set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, EncryptData,
// EncryptFunctionData::EncryptBind));

return set;
}
Expand All @@ -378,20 +345,8 @@ ScalarFunctionSet GetDecryptionFunction() {
set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, DecryptData,
EncryptFunctionData::EncryptBind));

set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, LogicalTypeId::BLOB, DecryptData,
EncryptFunctionData::EncryptBind));

set.AddFunction(ScalarFunction({LogicalTypeId::BIGINT, LogicalType::VARCHAR}, LogicalTypeId::BLOB, DecryptData,
EncryptFunctionData::EncryptBind));

set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BLOB, DecryptData,
EncryptFunctionData::EncryptBind));

// TODO; support all available types for encryption
// for (auto &type : LogicalType::AllTypes()) {
// set.AddFunction(ScalarFunction({type, LogicalType::VARCHAR}, type, DecryptData,
// EncryptFunctionData::EncryptBind));
// }
// set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BLOB, DecryptData,
// EncryptFunctionData::EncryptBind));

return set;
}
Expand Down
Loading

0 comments on commit 188ede7

Please sign in to comment.