From 3b658a11e506bc5606e8a971ac52d39f6e5e07e4 Mon Sep 17 00:00:00 2001 From: ccfelius Date: Wed, 18 Dec 2024 10:30:56 +0100 Subject: [PATCH 1/2] wip --- experiments.sql | 46 ++++++++ .../functions/scalar/encrypt_vectorized.cpp | 103 ++++++++++-------- 2 files changed, 105 insertions(+), 44 deletions(-) create mode 100644 experiments.sql diff --git a/experiments.sql b/experiments.sql new file mode 100644 index 0000000..b361111 --- /dev/null +++ b/experiments.sql @@ -0,0 +1,46 @@ +create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val uint128)); +insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1, 'val': (range & 7) * (cast(1 as uint128) << 124) + (range >>2)} s from range(100000000); +from pragma_storage_info('tst') where row_group_id=90; + +create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob)); +insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('0123456789012345678901234'|| cast(range >> 3 as string))} s from range(100000000); +from tst limit 10; +select len,count(*) from (select octet_length(s.val) len from tst) t group by len; +select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1; +select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1; +from pragma_storage_info('tst') where row_group_id=90; + +create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob)); +insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('012345678901234567890123456789012345678901234567890123456'|| cast(range >> 4 as string))} s from range(100000000); +from tst limit 10; +select len,count(*) from (select octet_length(s.val) len from tst) t group by len; +select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1; +from pragma_storage_info('tst') where row_group_id=90; +select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1; + +create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob)); +insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890'|| cast(range >> 5 as string))} s from range(100000000); +select len,count(*) from (select octet_length(s.val) len from tst) t group by len; +select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1; +from pragma_storage_info('tst') where row_group_id=90; +select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1; + +create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob)); +insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678'|| cast(range >> 6 as string))} s from range(100000000); +select len,count(*) from (select octet_length(s.val) len from tst) t group by len; +select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1; +from pragma_storage_info('tst') where row_group_id=90; +select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1; + +create or replace table tst(s struct(val blob)); +insert into tst select {'val':encode('01234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678990123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234'|| cast(range >> 7 as string))} s from range(100000000); +select len,count(*) from (select octet_length(s.val) len from tst) t group by len; +from pragma_storage_info('tst') where row_group_id=0; + +create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob)); +insert into tst select {'hi': 0,'lo': range>>11,'ctr': (range&2047)<<1, 'val':encode('0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567899012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456'|| cast(range >> 8 as string))} s from range(100000000); +select len,count(*) from (select octet_length(s.val) len from tst) t group by len; +select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1; +from pragma_storage_info('tst') where row_group_id=90; +select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1; + diff --git a/src/core/functions/scalar/encrypt_vectorized.cpp b/src/core/functions/scalar/encrypt_vectorized.cpp index 35acb3a..704b13d 100644 --- a/src/core/functions/scalar/encrypt_vectorized.cpp +++ b/src/core/functions/scalar/encrypt_vectorized.cpp @@ -167,10 +167,12 @@ shared_ptr GetEncryptionState(ExpressionState &state) { } // todo; template -LogicalType CreateEINTtypeStruct() { +LogicalType CreateBLOBStruct() { return LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT}, {"nonce_lo", LogicalType::UBIGINT}, - {"value", LogicalType::INTEGER}}); + {"counter", LogicalType::UINTEGER}, + {"cipher", LogicalType::TINYINT}, + {"value", LogicalType::BLOB}}); } LogicalType CreateEVARtypeStruct() { @@ -180,61 +182,60 @@ LogicalType CreateEVARtypeStruct() { } template -void EncryptToEtype(LogicalType result_struct, Vector &input_vector, +uint32_t CalculateBlockCounter(uint32_t counter) { + return ceil(counter * sizeof(T) / 16); +} + +template +void EncryptVectorized(LogicalType result_struct, Vector &input_vector, uint64_t size, ExpressionState &state, Vector &result) { + // local, global and encryption state auto &lstate = SimpleEncryptionFunctionLocalState::ResetAndGet(state); - - // this is the global state auto simple_encryption_state = GetSimpleEncryptionState(state); auto encryption_state = GetEncryptionState(state); - - // Get Key from Bind auto key = GetKey(state); // Reset the reference of the result vector - Vector struct_vector(result_struct, size); + Vector struct_vector(CreateBLOBStruct(), size); result.ReferenceAndSetType(struct_vector); - // ValidityMask &result_validity = FlatVector::Validity(result); - - if ((simple_encryption_state->counter == 0) || (HasSpace(simple_encryption_state, size) == false)) { - // generate new random IV and reset counter (if strart or if there is no space left) - SetIV(simple_encryption_state); - simple_encryption_state->counter = 0; - } - auto &children = StructVector::GetEntries(result); auto &nonce_hi = children[0]; + auto &nonce_lo = children[1]; + auto &counter = children[2]; + auto &cipher = children[3]; + + // result vector containing encrypted data + auto &blob = children[4]; + nonce_hi->SetVectorType(VectorType::CONSTANT_VECTOR); + nonce_lo->SetVectorType(VectorType::CONSTANT_VECTOR); - auto nonce_lo = simple_encryption_state->iv[1]; + // put counter also in the local state + counter->SetVectorType(VectorType::FLAT_VECTOR); + cipher->SetVectorType(VectorType::FLAT_VECTOR); - using ENCRYPTED_TYPE = StructTypeTernary; - using PLAINTEXT_TYPE = PrimitiveType; +// auto nonce_lo = simple_encryption_state->iv[1]; - encryption_state->InitializeEncryption( - reinterpret_cast(simple_encryption_state->iv), 16, - reinterpret_cast(key)); +// using ENCRYPTED_TYPE = StructTypeTernary; +// using PLAINTEXT_TYPE = PrimitiveType; + // but how do we increase the counter? - GenericExecutor::ExecuteUnary( - input_vector, result, size, [&](PLAINTEXT_TYPE input) { - simple_encryption_state->iv[1]++; - simple_encryption_state->counter++; - encryption_state->InitializeEncryption( - reinterpret_cast(simple_encryption_state->iv), 16, - reinterpret_cast(key)); + encryption_state->InitializeEncryption( + reinterpret_cast(simple_encryption_state->iv), 16, + key); - T encrypted_data = - ProcessVectorizedEncrypt(encryption_state, result, input.val, - lstate.buffer_p); + T encrypted_data = + ProcessVectorizedEncrypt(encryption_state, result, input.val, + lstate.buffer_p); return ENCRYPTED_TYPE{simple_encryption_state->iv[0], simple_encryption_state->iv[1], encrypted_data}; - }); + }); } @@ -272,46 +273,60 @@ void DecryptFromEtype(Vector &input_vector, uint64_t size, } -static void EncryptDataToEtype(DataChunk &args, ExpressionState &state, +static void EncryptDataVectorized(DataChunk &args, ExpressionState &state, Vector &result) { auto &input_vector = args.data[0]; auto vector_type = input_vector.GetType(); auto size = args.size(); + // get src and dst vectors for searches + auto &src = args.data[2]; + auto &dst = args.data[3]; + + UnifiedVectorFormat vdata_src; + UnifiedVectorFormat vdata_dst; + + src.ToUnifiedFormat(args.size(), vdata_src); + dst.ToUnifiedFormat(args.size(), vdata_dst); + auto src_data = (int64_t *)vdata_src.data; + auto dst_data = (int64_t *)vdata_dst.data; + + ValidityMask &result_validity = FlatVector::Validity(result); + if (vector_type.IsNumeric()) { switch (vector_type.id()) { case LogicalTypeId::TINYINT: case LogicalTypeId::UTINYINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized(CreateEINTtypeStruct(), input_vector, size, state, result); case LogicalTypeId::SMALLINT: case LogicalTypeId::USMALLINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized(CreateEINTtypeStruct(), input_vector, size, state, result); case LogicalTypeId::INTEGER: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized(CreateEINTtypeStruct(), input_vector, size, state, result); case LogicalTypeId::UINTEGER: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized(CreateEINTtypeStruct(), input_vector, size, state, result); case LogicalTypeId::BIGINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized(CreateEINTtypeStruct(), input_vector, size, state, result); case LogicalTypeId::UBIGINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized(CreateEINTtypeStruct(), input_vector, size, state, result); case LogicalTypeId::FLOAT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized(CreateEINTtypeStruct(), input_vector, size, state, result); case LogicalTypeId::DOUBLE: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized(CreateEINTtypeStruct(), input_vector, size, state, result); default: throw NotImplementedException("Unsupported numeric type for encryption"); } } else if (vector_type.id() == LogicalTypeId::VARCHAR) { - return EncryptToEtype(CreateEVARtypeStruct(), input_vector, + return EncryptVectorized(CreateEVARtypeStruct(), input_vector, size, state, result); } else if (vector_type.IsNested()) { throw NotImplementedException( @@ -381,7 +396,7 @@ ScalarFunctionSet GetEncryptionStructFunction() { LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT}, {"nonce_lo", LogicalType::UBIGINT}, {"value", type}}), - EncryptDataToEtype, EncryptFunctionData::EncryptBind, nullptr, nullptr, SimpleEncryptionFunctionLocalState::Init)); + EncryptDataVectorized, EncryptFunctionData::EncryptBind, nullptr, nullptr, SimpleEncryptionFunctionLocalState::Init)); } return set; From 38cc872734832b710964bf4714b8b0ed9b6f74d0 Mon Sep 17 00:00:00 2001 From: ccfelius Date: Thu, 19 Dec 2024 21:25:37 +0100 Subject: [PATCH 2/2] refactoring and adding vectorized --- CMakeLists.txt | 2 +- src/core/functions/common.cpp | 33 +- .../function_data/encrypt_function_data.cpp | 10 +- src/core/functions/scalar/CMakeLists.txt | 2 +- src/core/functions/scalar/encrypt.cpp | 30 +- ...encrypt_to_etype.cpp => encrypt_naive.cpp} | 27 +- .../functions/scalar/encrypt_vectorized.cpp | 339 ++++++------------ .../core/functions/common.hpp | 5 +- .../function_data/encrypt_function_data.hpp | 3 +- .../core/functions/scalar.hpp | 2 + .../core/functions/scalar/encrypt.hpp | 17 +- test/sql/compression/dict_compression.test | 28 ++ test/sql/vectorized/vectorized_encrypt.test | 26 ++ 13 files changed, 243 insertions(+), 281 deletions(-) rename src/core/functions/scalar/{encrypt_to_etype.cpp => encrypt_naive.cpp} (94%) create mode 100644 test/sql/compression/dict_compression.test create mode 100644 test/sql/vectorized/vectorized_encrypt.test diff --git a/CMakeLists.txt b/CMakeLists.txt index 8879e92..543dbc3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,7 +26,7 @@ set(EXTENSION_SOURCES src/core/module.cpp src/core/types.cpp src/core/functions/scalar/encrypt.cpp - src/core/functions/scalar/encrypt_to_etype.cpp + src/core/functions/scalar/encrypt_naive.cpp src/core/functions/scalar/encrypt_vectorized.cpp src/core/functions/function_data/encrypt_function_data.cpp src/core/functions/cast/varchar_cast.cpp diff --git a/src/core/functions/common.cpp b/src/core/functions/common.cpp index 2afe06f..2cc1e00 100644 --- a/src/core/functions/common.cpp +++ b/src/core/functions/common.cpp @@ -12,12 +12,26 @@ SimpleEncryptionFunctionLocalState::SimpleEncryptionFunctionLocalState(ClientCon iv[0] = iv[1] = 0; // maybe generate iv_high also already in the bind + // allocate depending in sizeof(T) * items_in_vector + // maybe already in registering the function! + size_t data_size; + LogicalType type = bind_data->type; - // for now do 512 bytes - buffer_length = 512; - encryption_buffer = arena.Allocate(buffer_length); + // todo; fix this for all other types + if (type == LogicalType::VARCHAR) { + // allocate buffer for encrypted data + data_size = 512; + } else { + // maybe we can also just do per vector for certain types, so more then 128 + data_size = GetTypeIdSize(type.InternalType()) * 128; + } - buffer_p = (data_ptr_t)encryption_buffer; + buffer_p = (data_ptr_t)arena.Allocate(data_size); + + if (bind_data->type.id() == LogicalTypeId::VARCHAR) { + // allocate buffer for encrypted data + buffer_p = (data_ptr_t)arena.Allocate(128); + } } unique_ptr @@ -25,12 +39,23 @@ SimpleEncryptionFunctionLocalState::Init(ExpressionState &state, const BoundFunc return make_uniq(state.GetContext(), static_cast(bind_data)); } +SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::Get(ExpressionState &state) { + auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast(); + return local_state; +} + SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::ResetAndGet(ExpressionState &state) { auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast(); local_state.arena.Reset(); return local_state; } +SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::AllocateAndGet(ExpressionState &state, idx_t buffer_size) { + auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast(); + local_state.arena.Allocate(buffer_size); + return local_state; +} + SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::ResetKeyAndGet(ExpressionState &state) { auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast(); local_state.arena.Reset(); diff --git a/src/core/functions/function_data/encrypt_function_data.cpp b/src/core/functions/function_data/encrypt_function_data.cpp index c70f987..17b6837 100644 --- a/src/core/functions/function_data/encrypt_function_data.cpp +++ b/src/core/functions/function_data/encrypt_function_data.cpp @@ -15,7 +15,7 @@ struct KeyData { }; unique_ptr EncryptFunctionData::Copy() const { - return make_uniq(context, key_name); + return make_uniq(context, key_name, type); } bool EncryptFunctionData::Equals(const FunctionData &other_p) const { @@ -64,6 +64,12 @@ EncryptFunctionData::EncryptBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { + auto &value = arguments[0]; + + if (arguments.size() != 2) { + throw BinderException("Encrypt Scalar Function requires two arguments"); + } + auto &key_child = arguments[1]; if (key_child->HasParameter()) { throw ParameterNotResolvedException(); @@ -81,7 +87,7 @@ EncryptFunctionData::EncryptBind(ClientContext &context, auto key_name = StringUtil::Lower(key_str); - return make_uniq(context, key_name); + return make_uniq(context, key_name, value->return_type); } } // namespace core } // namespace simple_encryption diff --git a/src/core/functions/scalar/CMakeLists.txt b/src/core/functions/scalar/CMakeLists.txt index 80d083f..a2771cc 100644 --- a/src/core/functions/scalar/CMakeLists.txt +++ b/src/core/functions/scalar/CMakeLists.txt @@ -1,6 +1,6 @@ set(EXTENSION_SOURCES ${EXTENSION_SOURCES} ${CMAKE_CURRENT_SOURCE_DIR}/encrypt.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/encrypt_to_etype.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/encrypt_naive.cpp PARENT_SCOPE ) \ No newline at end of file diff --git a/src/core/functions/scalar/encrypt.cpp b/src/core/functions/scalar/encrypt.cpp index 30dfa46..a894d53 100644 --- a/src/core/functions/scalar/encrypt.cpp +++ b/src/core/functions/scalar/encrypt.cpp @@ -15,16 +15,38 @@ #include #include "duckdb/common/types/blob.hpp" #include "duckdb/main/connection_manager.hpp" -#include "simple_encryption/core/functions/scalar/encrypt.hpp" -#include "simple_encryption/core/functions/scalar.hpp" -#include "simple_encryption_state.hpp" +#include "duckdb/common/encryption_state.hpp" #include "duckdb/main/client_context.hpp" -#include "simple_encryption/core/functions/function_data/encrypt_function_data.hpp" + +#include "simple_encryption_state.hpp" +#include "simple_encryption/core/functions/scalar.hpp" +#include "simple_encryption/core/functions/scalar/encrypt.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" namespace simple_encryption { namespace core { +EncryptFunctionData& VCryptBasicFun::GetEncryptionBindInfo(ExpressionState &state) { + auto &func_expr = (BoundFunctionExpression &)state.expr; + return (EncryptFunctionData &)*func_expr.bind_info; +} + +shared_ptr +VCryptBasicFun::GetSimpleEncryptionState(ExpressionState &state) { + auto &info = VCryptBasicFun::GetEncryptionBindInfo(state); + return info.context.registered_state->Get( + "simple_encryption"); +} +// TODO; maybe pass by reference or so +string* VCryptBasicFun::GetKey(ExpressionState &state) { + auto &info = VCryptBasicFun::GetEncryptionBindInfo(state); + return &info.key; +} + +shared_ptr VCryptBasicFun::GetEncryptionState(ExpressionState &state) { + return VCryptBasicFun::GetSimpleEncryptionState(state)->encryption_state; +} + shared_ptr GetEncryptionUtil(ExpressionState &state) { auto &func_expr = (BoundFunctionExpression &)state.expr; auto &info = (EncryptFunctionData &)*func_expr.bind_info; diff --git a/src/core/functions/scalar/encrypt_to_etype.cpp b/src/core/functions/scalar/encrypt_naive.cpp similarity index 94% rename from src/core/functions/scalar/encrypt_to_etype.cpp rename to src/core/functions/scalar/encrypt_naive.cpp index f4b0cba..130d852 100644 --- a/src/core/functions/scalar/encrypt_to_etype.cpp +++ b/src/core/functions/scalar/encrypt_naive.cpp @@ -167,21 +167,6 @@ GetSimpleEncryptionState(ExpressionState &state) { "simple_encryption"); } -// TODO; maybe pass by reference or so -string* GetKey(ExpressionState &state) { - auto &info = GetEncryptionBindInfo(state); - return &info.key; -} - -shared_ptr GetSimpleEncryptionStateLocal(ExpressionState &state) { - auto &info = GetEncryptionBindInfo(state); - // create a new local encryption state, but get the nonce etc. from the global state. - auto encryption_util = GetSimpleEncryptionState(state)->encryption_util; - - return info.context.registered_state->Get( - "simple_encryption")->encryption_util->CreateEncryptionState(); -} - bool HasSpace(shared_ptr simple_encryption_state, uint64_t size) { uint32_t max_value = ~0u; @@ -210,10 +195,6 @@ bool CheckGeneratedKeySize(const uint32_t size){ } } -shared_ptr GetEncryptionState(ExpressionState &state) { - return GetSimpleEncryptionState(state)->encryption_state; -} - // todo; template LogicalType CreateEINTtypeStruct() { return LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT}, @@ -236,10 +217,10 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector, // this is the global state auto simple_encryption_state = GetSimpleEncryptionState(state); - auto encryption_state = GetEncryptionState(state); + auto encryption_state = VCryptBasicFun::GetEncryptionState(state); // Get Key from Bind - auto key = GetKey(state); + auto key = VCryptBasicFun::GetKey(state); // Reset the reference of the result vector Vector struct_vector(result_struct, size); @@ -295,10 +276,10 @@ void DecryptFromEtype(Vector &input_vector, uint64_t size, auto &lstate = SimpleEncryptionFunctionLocalState::ResetAndGet(state); // global state auto simple_encryption_state = GetSimpleEncryptionState(state); - auto encryption_state = GetEncryptionState(state); + auto encryption_state = VCryptBasicFun::GetEncryptionState(state); // Get Key from Bind - auto key = GetKey(state); + auto key = VCryptBasicFun::GetKey(state); using ENCRYPTED_TYPE = StructTypeTernary; using PLAINTEXT_TYPE = PrimitiveType; diff --git a/src/core/functions/scalar/encrypt_vectorized.cpp b/src/core/functions/scalar/encrypt_vectorized.cpp index 704b13d..d53c010 100644 --- a/src/core/functions/scalar/encrypt_vectorized.cpp +++ b/src/core/functions/scalar/encrypt_vectorized.cpp @@ -17,8 +17,6 @@ #include #include "simple_encryption_state.hpp" -#include "simple_encryption/core/types.hpp" -#include "simple_encryption/core/crypto/crypto_primitives.hpp" #include "simple_encryption/core/functions/common.hpp" #include "simple_encryption/core/functions/scalar.hpp" #include "simple_encryption/core/functions/secrets.hpp" @@ -29,145 +27,7 @@ namespace simple_encryption { namespace core { -template -typename std::enable_if< - std::is_integral::value || std::is_floating_point::value, T>::type -ProcessVectorizedEncrypt(shared_ptr encryption_state, - Vector &result, T plaintext_data, uint8_t *buffer_p) { - T encrypted_data; - encryption_state->Process( - reinterpret_cast(&plaintext_data), sizeof(int32_t), - reinterpret_cast(&encrypted_data), sizeof(int32_t)); - return encrypted_data; -} - - -template -typename std::enable_if::value, T>::type -ProcessVectorizedEncrypt(shared_ptr encryption_state, - Vector &result, T plaintext_data, uint8_t *buffer_p) { - - auto &children = StructVector::GetEntries(result); - auto &result_vector = children[2]; - - // first encrypt the bytes of the string into a temp buffer_p - auto input_data = data_ptr_t(plaintext_data.GetData()); - auto value_size = plaintext_data.GetSize(); - - encryption_state->Process(input_data, value_size, buffer_p, value_size); - - // Convert the encrypted data to a BLOB - auto encrypted_data = - string_t(reinterpret_cast(buffer_p), value_size); - size_t base64_size = Blob::ToBase64Size(encrypted_data); - - // convert to Base64 into a newly allocated string in the result vector - T base64_data = StringVector::EmptyString(*result_vector, base64_size); - memset(base64_data.GetDataWriteable(), 0, 12); - Blob::ToBase64(encrypted_data, base64_data.GetDataWriteable()); - - return base64_data; -} - - -template -typename std::enable_if::value, T>::type -ProcessVectorizedDecrypt(shared_ptr encryption_state, - Vector &result, T base64_data, uint8_t *buffer_p) { - - // we cann just fix the blob_size - // first encrypt the bytes of the string into a temp buffer_p - size_t encrypted_size = Blob::FromBase64Size(base64_data); - size_t decrypted_size = encrypted_size; - Blob::FromBase64(base64_data, reinterpret_cast(buffer_p), - encrypted_size); - - D_ASSERT(encrypted_size <= base64_data.GetSize()); - - string_t decrypted_data = - StringVector::EmptyString(result, decrypted_size); - encryption_state->Process( - buffer_p, encrypted_size, - reinterpret_cast(decrypted_data.GetDataWriteable()), - decrypted_size); - - return decrypted_data; -} - -template -typename std::enable_if< - std::is_integral::value || std::is_floating_point::value, T>::type -ProcessVectorizedDecrypt(shared_ptr encryption_state, - Vector &result, T encrypted_data, uint8_t *buffer_p) { - T decrypted_data; - encryption_state->Process( - reinterpret_cast(&encrypted_data), sizeof(T), - reinterpret_cast(&decrypted_data), sizeof(T)); - return decrypted_data; -} - -EncryptFunctionData &GetEncryptionBindInfo(ExpressionState &state) { - auto &func_expr = (BoundFunctionExpression &)state.expr; - return (EncryptFunctionData &)*func_expr.bind_info; -} - -shared_ptr -GetSimpleEncryptionState(ExpressionState &state) { - - auto &info = GetEncryptionBindInfo(state); - return info.context.registered_state->Get( - "simple_encryption"); -} - -// TODO; maybe pass by reference or so -string* GetKey(ExpressionState &state) { - auto &info = GetEncryptionBindInfo(state); - return &info.key; -} - -shared_ptr GetSimpleEncryptionStateLocal(ExpressionState &state) { - auto &info = GetEncryptionBindInfo(state); - // create a new local encryption state, but get the nonce etc. from the global state. - auto encryption_util = GetSimpleEncryptionState(state)->encryption_util; - - return info.context.registered_state->Get( - "simple_encryption")->encryption_util->CreateEncryptionState(); -} - -bool HasSpace(shared_ptr simple_encryption_state, - uint64_t size) { - uint32_t max_value = ~0u; - if ((max_value - simple_encryption_state->counter) > size) { - return true; - } - return false; -} - - -void SetIV(shared_ptr simple_encryption_state) { - simple_encryption_state->iv[1] = 0; - simple_encryption_state->encryption_state->GenerateRandomData( - reinterpret_cast(simple_encryption_state->iv), 12); -} - -bool CheckGeneratedKeySize(const uint32_t size){ - - switch(size){ - case 16: - case 24: - case 32: - return true; - default: - return false; - } -} - -shared_ptr GetEncryptionState(ExpressionState &state) { - return GetSimpleEncryptionState(state)->encryption_state; -} - -// todo; template -LogicalType CreateBLOBStruct() { +LogicalType CreateEncryptionStruct() { return LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT}, {"nonce_lo", LogicalType::UBIGINT}, {"counter", LogicalType::UINTEGER}, @@ -175,68 +35,88 @@ LogicalType CreateBLOBStruct() { {"value", LogicalType::BLOB}}); } -LogicalType CreateEVARtypeStruct() { - return LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT}, - {"nonce_lo", LogicalType::UBIGINT}, - {"value", LogicalType::VARCHAR}}); -} - template -uint32_t CalculateBlockCounter(uint32_t counter) { - return ceil(counter * sizeof(T) / 16); -} - -template -void EncryptVectorized(LogicalType result_struct, Vector &input_vector, - uint64_t size, ExpressionState &state, - Vector &result) { +void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, Vector &result) { // local, global and encryption state auto &lstate = SimpleEncryptionFunctionLocalState::ResetAndGet(state); - auto simple_encryption_state = GetSimpleEncryptionState(state); - auto encryption_state = GetEncryptionState(state); - auto key = GetKey(state); + auto simple_encryption_state = VCryptBasicFun::GetSimpleEncryptionState(state); + auto encryption_state = VCryptBasicFun::GetEncryptionState(state); + auto key = VCryptBasicFun::GetKey(state); - // Reset the reference of the result vector - Vector struct_vector(CreateBLOBStruct(), size); + Vector struct_vector(CreateEncryptionStruct(), size); result.ReferenceAndSetType(struct_vector); auto &children = StructVector::GetEntries(result); auto &nonce_hi = children[0]; auto &nonce_lo = children[1]; - auto &counter = children[2]; - auto &cipher = children[3]; + auto &counter_vec = children[2]; + auto &cipher_vec = children[3]; // result vector containing encrypted data auto &blob = children[4]; + // set the constant vectors nonce_hi->SetVectorType(VectorType::CONSTANT_VECTOR); nonce_lo->SetVectorType(VectorType::CONSTANT_VECTOR); - // put counter also in the local state - counter->SetVectorType(VectorType::FLAT_VECTOR); - cipher->SetVectorType(VectorType::FLAT_VECTOR); - -// auto nonce_lo = simple_encryption_state->iv[1]; - -// using ENCRYPTED_TYPE = StructTypeTernary; -// using PLAINTEXT_TYPE = PrimitiveType; - // but how do we increase the counter? - - - - encryption_state->InitializeEncryption( - reinterpret_cast(simple_encryption_state->iv), 16, - key); - - T encrypted_data = - ProcessVectorizedEncrypt(encryption_state, result, input.val, - lstate.buffer_p); - - return ENCRYPTED_TYPE{simple_encryption_state->iv[0], - simple_encryption_state->iv[1], encrypted_data}; - }); - + auto nonce_hi_64 = simple_encryption_state->iv[0]; + auto nonce_lo_32 = simple_encryption_state->iv[0]; + + // is not the pointer but really the actual value copied? + // Set constant vectors to a single value + nonce_hi->Reference(Value::UBIGINT(nonce_hi_64)); + nonce_hi->Reference(Value::UBIGINT(nonce_lo_32)); + + counter_vec->SetVectorType(VectorType::FLAT_VECTOR); + cipher_vec->SetVectorType(VectorType::FLAT_VECTOR); + + // Set the blob vector to dict vector for compressed execution + blob->SetVectorType(VectorType::DICTIONARY_VECTOR); + + encryption_state->InitializeEncryption(reinterpret_cast(simple_encryption_state->iv), 16, key); + + auto &blob_child = DictionaryVector::Child(*blob); + auto &blob_sel = DictionaryVector::SelVector(*blob); + + // we process in batches of 128 values, or we can do it with all and cut at each 128 * sizeof(T) bits (only works for similar lengths) + // fill 512 bytes, so 512 / sizeof(T) values and at least 128 values. + // note: this only works for fixed-size types + auto batch_size = 128 * sizeof(T); + auto total_size = sizeof(T) * size; + // and the cipher + + // todo: assign buffer_p with the right size + encryption_state->Process(reinterpret_cast(input_vector), total_size, lstate.buffer_p, total_size); + + auto index = 0; + auto batch_nr = 0; + // get counter from local state + uint32_t counter = 0; + uint8_t cipher = 0; + const size_t step = sizeof(T) / 16; + uint64_t buffer_offset; + + // TODO: for strings this all works slighly different + for(int i = 0; i + 128; i < (DEFAULT_STANDARD_VECTOR_SIZE / 128)){ + + buffer_offset = batch_nr * sizeof(T) * 128; + // Allocate space in the dictionary vector (i.e. blob_child) + string_t batch_data = StringVector::EmptyString(blob_child, batch_size); // value size + *(uint32_t*) batch_data.GetPrefixWriteable() = *(uint32_t *) lstate.buffer_p + buffer_offset; + memcpy(batch_data.GetDataWriteable(), lstate.buffer_p, batch_size); + + // set index in selection vector + for (int j = 0; j++; j < 128){ + cipher = j % step; + counter += (cipher == 0 && index != 0) ? 1 : 0; + cipher_vec->SetValue(index, Value::TINYINT(cipher)); + counter_vec->SetValue(index, Value::UINTEGER(counter)); + blob_sel.set_index(index, batch_nr); + index++; + } + batch_nr++; + } } @@ -247,29 +127,29 @@ void DecryptFromEtype(Vector &input_vector, uint64_t size, // local state (contains key, buffer, iv etc.) auto &lstate = SimpleEncryptionFunctionLocalState::ResetAndGet(state); // global state - auto simple_encryption_state = GetSimpleEncryptionState(state); - auto encryption_state = GetEncryptionState(state); + auto simple_encryption_state = VCryptBasicFun::GetSimpleEncryptionState(state); + auto encryption_state = VCryptBasicFun::GetEncryptionState(state); // Get Key from Bind - auto key = GetKey(state); + auto key = VCryptBasicFun::GetKey(state); - using ENCRYPTED_TYPE = StructTypeTernary; - using PLAINTEXT_TYPE = PrimitiveType; - - GenericExecutor::ExecuteUnary( - input_vector, result, size, [&](ENCRYPTED_TYPE input) { - simple_encryption_state->iv[0] = input.a_val; - simple_encryption_state->iv[1] = input.b_val; - - encryption_state->InitializeDecryption( - reinterpret_cast(simple_encryption_state->iv), 12, - reinterpret_cast(key)); - - T decrypted_data = - ProcessVectorizedDecrypt(encryption_state, result, input.c_val, - lstate.buffer_p); - return decrypted_data; - }); +// using ENCRYPTED_TYPE = StructTypeTernary; +// using PLAINTEXT_TYPE = PrimitiveType; +// +// GenericExecutor::ExecuteUnary( +// input_vector, result, size, [&](ENCRYPTED_TYPE input) { +// simple_encryption_state->iv[0] = input.a_val; +// simple_encryption_state->iv[1] = input.b_val; +// +// encryption_state->InitializeDecryption( +// reinterpret_cast(simple_encryption_state->iv), 12, +// reinterpret_cast(key)); +// +// T decrypted_data = +// ProcessVectorizedDecrypt(encryption_state, result, input.c_val, +// lstate.buffer_p); +// return decrypted_data; +// }); } @@ -280,53 +160,44 @@ static void EncryptDataVectorized(DataChunk &args, ExpressionState &state, auto vector_type = input_vector.GetType(); auto size = args.size(); - // get src and dst vectors for searches - auto &src = args.data[2]; - auto &dst = args.data[3]; - - UnifiedVectorFormat vdata_src; - UnifiedVectorFormat vdata_dst; - - src.ToUnifiedFormat(args.size(), vdata_src); - dst.ToUnifiedFormat(args.size(), vdata_dst); - auto src_data = (int64_t *)vdata_src.data; - auto dst_data = (int64_t *)vdata_dst.data; - + UnifiedVectorFormat vdata_input; + input_vector.ToUnifiedFormat(args.size(), vdata_input); ValidityMask &result_validity = FlatVector::Validity(result); + auto vd = vdata_input.data; if (vector_type.IsNumeric()) { switch (vector_type.id()) { case LogicalTypeId::TINYINT: case LogicalTypeId::UTINYINT: - return EncryptVectorized(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((int8_t *)vdata_input.data, size, state, result); case LogicalTypeId::SMALLINT: case LogicalTypeId::USMALLINT: - return EncryptVectorized(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((int16_t *)vdata_input.data, size, state, result); case LogicalTypeId::INTEGER: - return EncryptVectorized(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((int32_t *)vdata_input.data, size, state, result); case LogicalTypeId::UINTEGER: - return EncryptVectorized(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((uint32_t *)vdata_input.data, size, state, result); case LogicalTypeId::BIGINT: - return EncryptVectorized(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((int64_t *)vdata_input.data, size, state, result); case LogicalTypeId::UBIGINT: - return EncryptVectorized(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((uint64_t *)vdata_input.data, size, state, result); case LogicalTypeId::FLOAT: - return EncryptVectorized(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((float *)vdata_input.data, size, state, result); case LogicalTypeId::DOUBLE: - return EncryptVectorized(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((double *)vdata_input.data, size, state, result); default: throw NotImplementedException("Unsupported numeric type for encryption"); } } else if (vector_type.id() == LogicalTypeId::VARCHAR) { - return EncryptVectorized(CreateEVARtypeStruct(), input_vector, + return EncryptVectorized((string_t *)vdata_input.data, size, state, result); } else if (vector_type.IsNested()) { throw NotImplementedException( @@ -338,7 +209,7 @@ static void EncryptDataVectorized(DataChunk &args, ExpressionState &state, } -static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state, +static void DecryptDataVectorized(DataChunk &args, ExpressionState &state, Vector &result) { auto size = args.size(); @@ -387,7 +258,7 @@ static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state, } } -ScalarFunctionSet GetEncryptionStructFunction() { +ScalarFunctionSet GetEncryptionVectorizedFunction() { ScalarFunctionSet set("encrypt_vectorized"); for (auto &type : LogicalType::AllTypes()) { @@ -402,7 +273,7 @@ ScalarFunctionSet GetEncryptionStructFunction() { return set; } -ScalarFunctionSet GetDecryptionStructFunction() { +ScalarFunctionSet GetDecryptionVectorizedFunction() { ScalarFunctionSet set("decrypt_vectorized"); for (auto &type : LogicalType::AllTypes()) { @@ -413,7 +284,7 @@ ScalarFunctionSet GetDecryptionStructFunction() { {"nonce_lo", nonce_type_b}, {"value", type}}), LogicalType::VARCHAR}, - type, DecryptDataFromEtype, EncryptFunctionData::EncryptBind, nullptr, nullptr, SimpleEncryptionFunctionLocalState::Init)); + type, DecryptDataVectorized, EncryptFunctionData::EncryptBind, nullptr, nullptr, SimpleEncryptionFunctionLocalState::Init)); } } @@ -431,10 +302,10 @@ ScalarFunctionSet GetDecryptionStructFunction() { // Register functions //------------------------------------------------------------------------------ -void CoreScalarFunctions::RegisterEncryptDataStructScalarFunction( +void CoreScalarFunctions::RegisterEncryptVectorizedScalarFunction( DatabaseInstance &db) { - ExtensionUtil::RegisterFunction(db, GetEncryptionStructFunction()); - ExtensionUtil::RegisterFunction(db, GetDecryptionStructFunction()); + ExtensionUtil::RegisterFunction(db, GetEncryptionVectorizedFunction()); + ExtensionUtil::RegisterFunction(db, GetDecryptionVectorizedFunction()); } } // namespace core } // namespace simple_encryption diff --git a/src/include/simple_encryption/core/functions/common.hpp b/src/include/simple_encryption/core/functions/common.hpp index c98f795..6e041ca 100644 --- a/src/include/simple_encryption/core/functions/common.hpp +++ b/src/include/simple_encryption/core/functions/common.hpp @@ -10,9 +10,8 @@ struct SimpleEncryptionFunctionLocalState : FunctionLocalState { public: ArenaAllocator arena; - - idx_t buffer_length; uint64_t iv[2]; + uint32_t counter = 0; // todo: key can be 16, 24 or 32 unsigned char key[16]; @@ -25,7 +24,9 @@ struct SimpleEncryptionFunctionLocalState : FunctionLocalState { explicit SimpleEncryptionFunctionLocalState(ClientContext &context, EncryptFunctionData *bind_data); static unique_ptr Init(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data); + static SimpleEncryptionFunctionLocalState &Get(ExpressionState &state); static SimpleEncryptionFunctionLocalState &ResetAndGet(ExpressionState &state); + static SimpleEncryptionFunctionLocalState &AllocateAndGet(ExpressionState &state, idx_t buffer_size); static SimpleEncryptionFunctionLocalState &ResetKeyAndGet(ExpressionState &state); }; diff --git a/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp b/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp index 82ba286..5d2462c 100644 --- a/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp +++ b/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp @@ -13,9 +13,10 @@ struct EncryptFunctionData : FunctionData { // Save the Key string key_name; string key; + LogicalType type; // BoundStatement relation; - EncryptFunctionData(ClientContext &context, string key_name) : context(context), key_name(key_name) { + EncryptFunctionData(ClientContext &context, string key_name, LogicalType type) : context(context), key_name(key_name), type(type) { // generate encryption key and store key = GetKeyFromSecret(context, key_name); } diff --git a/src/include/simple_encryption/core/functions/scalar.hpp b/src/include/simple_encryption/core/functions/scalar.hpp index 4c12e97..02453a6 100644 --- a/src/include/simple_encryption/core/functions/scalar.hpp +++ b/src/include/simple_encryption/core/functions/scalar.hpp @@ -10,11 +10,13 @@ struct CoreScalarFunctions { static void Register(duckdb::DatabaseInstance &db) { RegisterEncryptDataScalarFunction(db); RegisterEncryptDataStructScalarFunction(db); + RegisterEncryptVectorizedScalarFunction(db); } private: static void RegisterEncryptDataScalarFunction(duckdb::DatabaseInstance &db); static void RegisterEncryptDataStructScalarFunction(duckdb::DatabaseInstance &db); + static void RegisterEncryptVectorizedScalarFunction(duckdb::DatabaseInstance &db); }; } // namespace core diff --git a/src/include/simple_encryption/core/functions/scalar/encrypt.hpp b/src/include/simple_encryption/core/functions/scalar/encrypt.hpp index 92a6601..dc08f78 100644 --- a/src/include/simple_encryption/core/functions/scalar/encrypt.hpp +++ b/src/include/simple_encryption/core/functions/scalar/encrypt.hpp @@ -2,6 +2,7 @@ #include "simple_encryption/common.hpp" #include "duckdb/common/encryption_state.hpp" +#include "simple_encryption/core/functions/function_data/encrypt_function_data.hpp" #ifndef DUCKDB_AMALGAMATION #include "duckdb/storage/object_cache.hpp" @@ -11,19 +12,17 @@ namespace simple_encryption { namespace core { -class SimpleEncryptKeys : public ObjectCacheEntry { +class VCryptBasicFun { public: - static SimpleEncryptKeys &Get(ClientContext &context); + // Fix this later + static VCryptBasicFun &Get(ClientContext &context); public: - void AddKey(const string &key_name, const string &key); - bool HasKey(const string &key_name) const; - const string &GetKey(const string &key_name) const; - -public: - static string ObjectType(); - string GetObjectType() override; + static string* GetKey(ExpressionState &state); + static EncryptFunctionData &GetEncryptionBindInfo(ExpressionState &state); + static shared_ptr GetSimpleEncryptionState(ExpressionState &state); + static shared_ptr GetEncryptionState(ExpressionState &state); private: unordered_map keys; diff --git a/test/sql/compression/dict_compression.test b/test/sql/compression/dict_compression.test new file mode 100644 index 0000000..0c20545 --- /dev/null +++ b/test/sql/compression/dict_compression.test @@ -0,0 +1,28 @@ +# name: test/sql/bugfix_varchar_struct.test +# description: test simple_struct_encryption extension +# group: [simple_encryption] + +# Require statement will ensure this test is run with this extension loaded +require simple_encryption + +load __TEST_DIR__/test.db + +statement ok +create table tst(s struct(val blob)); + +statement ok +insert into tst select {'val':encode('01234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678990123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234'|| cast(range >> 7 as string))} s from range(150000); + +statement ok +select octet_length(s.val) len from tst; + +query I +select compression, count(*) from pragma_storage_info('tst') where column_path='[0, 1]' group by 1; +---- +1 +1 +1 +1 + +#statement ok +#from pragma_storage_info('tst') where row_group_id=0; \ No newline at end of file diff --git a/test/sql/vectorized/vectorized_encrypt.test b/test/sql/vectorized/vectorized_encrypt.test new file mode 100644 index 0000000..8cf26ce --- /dev/null +++ b/test/sql/vectorized/vectorized_encrypt.test @@ -0,0 +1,26 @@ +# name: test/sql/vectorized/vectorized_encrypt.test +# description: Test vectorized encrypt scalar function +# group: [simple-encryption/vectorized] + +require simple_encryption + +# Ensure any currently stored secrets don't interfere with the test +statement ok +set allow_persistent_secrets=false; + +# Create an internal secret (for internal encryption of columns) +statement ok +CREATE SECRET key_1 ( + TYPE ENCRYPTION, + TOKEN '0123456789112345', + LENGTH 16 +); + +statement ok +CREATE TABLE test_1 AS SELECT 1 AS value FROM range(10000); + +statement ok +ALTER TABLE test_1 ADD COLUMN encrypted_values STRUCT(nonce_hi UBIGINT, nonce_lo UBIGINT, counter UINTEGER, cipher UINTEGER, value BLOB) DEFAULT (STRUCT_PACK(nonce_hi := 0, nonce_lo := 0, counter := 0, cipher := 0, BLOB := 0)); + +statement ok +UPDATE test_1 SET encrypted_values = encrypt_vectorized(value, 'key_1'); \ No newline at end of file