diff --git a/CMakeLists.txt b/CMakeLists.txt index 8879e92..543dbc3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,7 +26,7 @@ set(EXTENSION_SOURCES src/core/module.cpp src/core/types.cpp src/core/functions/scalar/encrypt.cpp - src/core/functions/scalar/encrypt_to_etype.cpp + src/core/functions/scalar/encrypt_naive.cpp src/core/functions/scalar/encrypt_vectorized.cpp src/core/functions/function_data/encrypt_function_data.cpp src/core/functions/cast/varchar_cast.cpp diff --git a/experiments.sql b/experiments.sql new file mode 100644 index 0000000..b361111 --- /dev/null +++ b/experiments.sql @@ -0,0 +1,46 @@ +create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val uint128)); +insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1, 'val': (range & 7) * (cast(1 as uint128) << 124) + (range >>2)} s from range(100000000); +from pragma_storage_info('tst') where row_group_id=90; + +create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob)); +insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('0123456789012345678901234'|| cast(range >> 3 as string))} s from range(100000000); +from tst limit 10; +select len,count(*) from (select octet_length(s.val) len from tst) t group by len; +select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1; +select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1; +from pragma_storage_info('tst') where row_group_id=90; + +create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob)); +insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('012345678901234567890123456789012345678901234567890123456'|| cast(range >> 4 as string))} s from range(100000000); +from tst limit 10; +select len,count(*) from (select octet_length(s.val) len from tst) t group by len; +select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1; +from pragma_storage_info('tst') where row_group_id=90; +select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1; + +create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob)); +insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890'|| cast(range >> 5 as string))} s from range(100000000); +select len,count(*) from (select octet_length(s.val) len from tst) t group by len; +select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1; +from pragma_storage_info('tst') where row_group_id=90; +select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1; + +create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob)); +insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678'|| cast(range >> 6 as string))} s from range(100000000); +select len,count(*) from (select octet_length(s.val) len from tst) t group by len; +select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1; +from pragma_storage_info('tst') where row_group_id=90; +select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1; + +create or replace table tst(s struct(val blob)); +insert into tst select {'val':encode('01234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678990123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234'|| cast(range >> 7 as string))} s from range(100000000); +select len,count(*) from (select octet_length(s.val) len from tst) t group by len; +from pragma_storage_info('tst') where row_group_id=0; + +create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob)); +insert into tst select {'hi': 0,'lo': range>>11,'ctr': (range&2047)<<1, 'val':encode('0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567899012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456'|| cast(range >> 8 as string))} s from range(100000000); +select len,count(*) from (select octet_length(s.val) len from tst) t group by len; +select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1; +from pragma_storage_info('tst') where row_group_id=90; +select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1; + diff --git a/src/core/functions/common.cpp b/src/core/functions/common.cpp index 2afe06f..2cc1e00 100644 --- a/src/core/functions/common.cpp +++ b/src/core/functions/common.cpp @@ -12,12 +12,26 @@ SimpleEncryptionFunctionLocalState::SimpleEncryptionFunctionLocalState(ClientCon iv[0] = iv[1] = 0; // maybe generate iv_high also already in the bind + // allocate depending in sizeof(T) * items_in_vector + // maybe already in registering the function! + size_t data_size; + LogicalType type = bind_data->type; - // for now do 512 bytes - buffer_length = 512; - encryption_buffer = arena.Allocate(buffer_length); + // todo; fix this for all other types + if (type == LogicalType::VARCHAR) { + // allocate buffer for encrypted data + data_size = 512; + } else { + // maybe we can also just do per vector for certain types, so more then 128 + data_size = GetTypeIdSize(type.InternalType()) * 128; + } - buffer_p = (data_ptr_t)encryption_buffer; + buffer_p = (data_ptr_t)arena.Allocate(data_size); + + if (bind_data->type.id() == LogicalTypeId::VARCHAR) { + // allocate buffer for encrypted data + buffer_p = (data_ptr_t)arena.Allocate(128); + } } unique_ptr @@ -25,12 +39,23 @@ SimpleEncryptionFunctionLocalState::Init(ExpressionState &state, const BoundFunc return make_uniq(state.GetContext(), static_cast(bind_data)); } +SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::Get(ExpressionState &state) { + auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast(); + return local_state; +} + SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::ResetAndGet(ExpressionState &state) { auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast(); local_state.arena.Reset(); return local_state; } +SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::AllocateAndGet(ExpressionState &state, idx_t buffer_size) { + auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast(); + local_state.arena.Allocate(buffer_size); + return local_state; +} + SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::ResetKeyAndGet(ExpressionState &state) { auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast(); local_state.arena.Reset(); diff --git a/src/core/functions/function_data/encrypt_function_data.cpp b/src/core/functions/function_data/encrypt_function_data.cpp index c70f987..17b6837 100644 --- a/src/core/functions/function_data/encrypt_function_data.cpp +++ b/src/core/functions/function_data/encrypt_function_data.cpp @@ -15,7 +15,7 @@ struct KeyData { }; unique_ptr EncryptFunctionData::Copy() const { - return make_uniq(context, key_name); + return make_uniq(context, key_name, type); } bool EncryptFunctionData::Equals(const FunctionData &other_p) const { @@ -64,6 +64,12 @@ EncryptFunctionData::EncryptBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { + auto &value = arguments[0]; + + if (arguments.size() != 2) { + throw BinderException("Encrypt Scalar Function requires two arguments"); + } + auto &key_child = arguments[1]; if (key_child->HasParameter()) { throw ParameterNotResolvedException(); @@ -81,7 +87,7 @@ EncryptFunctionData::EncryptBind(ClientContext &context, auto key_name = StringUtil::Lower(key_str); - return make_uniq(context, key_name); + return make_uniq(context, key_name, value->return_type); } } // namespace core } // namespace simple_encryption diff --git a/src/core/functions/scalar/CMakeLists.txt b/src/core/functions/scalar/CMakeLists.txt index 80d083f..a2771cc 100644 --- a/src/core/functions/scalar/CMakeLists.txt +++ b/src/core/functions/scalar/CMakeLists.txt @@ -1,6 +1,6 @@ set(EXTENSION_SOURCES ${EXTENSION_SOURCES} ${CMAKE_CURRENT_SOURCE_DIR}/encrypt.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/encrypt_to_etype.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/encrypt_naive.cpp PARENT_SCOPE ) \ No newline at end of file diff --git a/src/core/functions/scalar/encrypt.cpp b/src/core/functions/scalar/encrypt.cpp index 30dfa46..a894d53 100644 --- a/src/core/functions/scalar/encrypt.cpp +++ b/src/core/functions/scalar/encrypt.cpp @@ -15,16 +15,38 @@ #include #include "duckdb/common/types/blob.hpp" #include "duckdb/main/connection_manager.hpp" -#include "simple_encryption/core/functions/scalar/encrypt.hpp" -#include "simple_encryption/core/functions/scalar.hpp" -#include "simple_encryption_state.hpp" +#include "duckdb/common/encryption_state.hpp" #include "duckdb/main/client_context.hpp" -#include "simple_encryption/core/functions/function_data/encrypt_function_data.hpp" + +#include "simple_encryption_state.hpp" +#include "simple_encryption/core/functions/scalar.hpp" +#include "simple_encryption/core/functions/scalar/encrypt.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" namespace simple_encryption { namespace core { +EncryptFunctionData& VCryptBasicFun::GetEncryptionBindInfo(ExpressionState &state) { + auto &func_expr = (BoundFunctionExpression &)state.expr; + return (EncryptFunctionData &)*func_expr.bind_info; +} + +shared_ptr +VCryptBasicFun::GetSimpleEncryptionState(ExpressionState &state) { + auto &info = VCryptBasicFun::GetEncryptionBindInfo(state); + return info.context.registered_state->Get( + "simple_encryption"); +} +// TODO; maybe pass by reference or so +string* VCryptBasicFun::GetKey(ExpressionState &state) { + auto &info = VCryptBasicFun::GetEncryptionBindInfo(state); + return &info.key; +} + +shared_ptr VCryptBasicFun::GetEncryptionState(ExpressionState &state) { + return VCryptBasicFun::GetSimpleEncryptionState(state)->encryption_state; +} + shared_ptr GetEncryptionUtil(ExpressionState &state) { auto &func_expr = (BoundFunctionExpression &)state.expr; auto &info = (EncryptFunctionData &)*func_expr.bind_info; diff --git a/src/core/functions/scalar/encrypt_to_etype.cpp b/src/core/functions/scalar/encrypt_naive.cpp similarity index 94% rename from src/core/functions/scalar/encrypt_to_etype.cpp rename to src/core/functions/scalar/encrypt_naive.cpp index f4b0cba..130d852 100644 --- a/src/core/functions/scalar/encrypt_to_etype.cpp +++ b/src/core/functions/scalar/encrypt_naive.cpp @@ -167,21 +167,6 @@ GetSimpleEncryptionState(ExpressionState &state) { "simple_encryption"); } -// TODO; maybe pass by reference or so -string* GetKey(ExpressionState &state) { - auto &info = GetEncryptionBindInfo(state); - return &info.key; -} - -shared_ptr GetSimpleEncryptionStateLocal(ExpressionState &state) { - auto &info = GetEncryptionBindInfo(state); - // create a new local encryption state, but get the nonce etc. from the global state. - auto encryption_util = GetSimpleEncryptionState(state)->encryption_util; - - return info.context.registered_state->Get( - "simple_encryption")->encryption_util->CreateEncryptionState(); -} - bool HasSpace(shared_ptr simple_encryption_state, uint64_t size) { uint32_t max_value = ~0u; @@ -210,10 +195,6 @@ bool CheckGeneratedKeySize(const uint32_t size){ } } -shared_ptr GetEncryptionState(ExpressionState &state) { - return GetSimpleEncryptionState(state)->encryption_state; -} - // todo; template LogicalType CreateEINTtypeStruct() { return LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT}, @@ -236,10 +217,10 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector, // this is the global state auto simple_encryption_state = GetSimpleEncryptionState(state); - auto encryption_state = GetEncryptionState(state); + auto encryption_state = VCryptBasicFun::GetEncryptionState(state); // Get Key from Bind - auto key = GetKey(state); + auto key = VCryptBasicFun::GetKey(state); // Reset the reference of the result vector Vector struct_vector(result_struct, size); @@ -295,10 +276,10 @@ void DecryptFromEtype(Vector &input_vector, uint64_t size, auto &lstate = SimpleEncryptionFunctionLocalState::ResetAndGet(state); // global state auto simple_encryption_state = GetSimpleEncryptionState(state); - auto encryption_state = GetEncryptionState(state); + auto encryption_state = VCryptBasicFun::GetEncryptionState(state); // Get Key from Bind - auto key = GetKey(state); + auto key = VCryptBasicFun::GetKey(state); using ENCRYPTED_TYPE = StructTypeTernary; using PLAINTEXT_TYPE = PrimitiveType; diff --git a/src/core/functions/scalar/encrypt_vectorized.cpp b/src/core/functions/scalar/encrypt_vectorized.cpp index 35acb3a..d53c010 100644 --- a/src/core/functions/scalar/encrypt_vectorized.cpp +++ b/src/core/functions/scalar/encrypt_vectorized.cpp @@ -17,8 +17,6 @@ #include #include "simple_encryption_state.hpp" -#include "simple_encryption/core/types.hpp" -#include "simple_encryption/core/crypto/crypto_primitives.hpp" #include "simple_encryption/core/functions/common.hpp" #include "simple_encryption/core/functions/scalar.hpp" #include "simple_encryption/core/functions/secrets.hpp" @@ -29,213 +27,96 @@ namespace simple_encryption { namespace core { -template -typename std::enable_if< - std::is_integral::value || std::is_floating_point::value, T>::type -ProcessVectorizedEncrypt(shared_ptr encryption_state, - Vector &result, T plaintext_data, uint8_t *buffer_p) { - T encrypted_data; - encryption_state->Process( - reinterpret_cast(&plaintext_data), sizeof(int32_t), - reinterpret_cast(&encrypted_data), sizeof(int32_t)); - return encrypted_data; -} - - -template -typename std::enable_if::value, T>::type -ProcessVectorizedEncrypt(shared_ptr encryption_state, - Vector &result, T plaintext_data, uint8_t *buffer_p) { - - auto &children = StructVector::GetEntries(result); - auto &result_vector = children[2]; - - // first encrypt the bytes of the string into a temp buffer_p - auto input_data = data_ptr_t(plaintext_data.GetData()); - auto value_size = plaintext_data.GetSize(); - - encryption_state->Process(input_data, value_size, buffer_p, value_size); - - // Convert the encrypted data to a BLOB - auto encrypted_data = - string_t(reinterpret_cast(buffer_p), value_size); - size_t base64_size = Blob::ToBase64Size(encrypted_data); - - // convert to Base64 into a newly allocated string in the result vector - T base64_data = StringVector::EmptyString(*result_vector, base64_size); - memset(base64_data.GetDataWriteable(), 0, 12); - Blob::ToBase64(encrypted_data, base64_data.GetDataWriteable()); - - return base64_data; -} - - -template -typename std::enable_if::value, T>::type -ProcessVectorizedDecrypt(shared_ptr encryption_state, - Vector &result, T base64_data, uint8_t *buffer_p) { - - // we cann just fix the blob_size - // first encrypt the bytes of the string into a temp buffer_p - size_t encrypted_size = Blob::FromBase64Size(base64_data); - size_t decrypted_size = encrypted_size; - Blob::FromBase64(base64_data, reinterpret_cast(buffer_p), - encrypted_size); - - D_ASSERT(encrypted_size <= base64_data.GetSize()); - - string_t decrypted_data = - StringVector::EmptyString(result, decrypted_size); - encryption_state->Process( - buffer_p, encrypted_size, - reinterpret_cast(decrypted_data.GetDataWriteable()), - decrypted_size); - - return decrypted_data; -} - -template -typename std::enable_if< - std::is_integral::value || std::is_floating_point::value, T>::type -ProcessVectorizedDecrypt(shared_ptr encryption_state, - Vector &result, T encrypted_data, uint8_t *buffer_p) { - T decrypted_data; - encryption_state->Process( - reinterpret_cast(&encrypted_data), sizeof(T), - reinterpret_cast(&decrypted_data), sizeof(T)); - return decrypted_data; -} - -EncryptFunctionData &GetEncryptionBindInfo(ExpressionState &state) { - auto &func_expr = (BoundFunctionExpression &)state.expr; - return (EncryptFunctionData &)*func_expr.bind_info; -} - -shared_ptr -GetSimpleEncryptionState(ExpressionState &state) { - - auto &info = GetEncryptionBindInfo(state); - return info.context.registered_state->Get( - "simple_encryption"); -} - -// TODO; maybe pass by reference or so -string* GetKey(ExpressionState &state) { - auto &info = GetEncryptionBindInfo(state); - return &info.key; -} - -shared_ptr GetSimpleEncryptionStateLocal(ExpressionState &state) { - auto &info = GetEncryptionBindInfo(state); - // create a new local encryption state, but get the nonce etc. from the global state. - auto encryption_util = GetSimpleEncryptionState(state)->encryption_util; - - return info.context.registered_state->Get( - "simple_encryption")->encryption_util->CreateEncryptionState(); -} - -bool HasSpace(shared_ptr simple_encryption_state, - uint64_t size) { - uint32_t max_value = ~0u; - if ((max_value - simple_encryption_state->counter) > size) { - return true; - } - return false; -} - - -void SetIV(shared_ptr simple_encryption_state) { - simple_encryption_state->iv[1] = 0; - simple_encryption_state->encryption_state->GenerateRandomData( - reinterpret_cast(simple_encryption_state->iv), 12); -} - -bool CheckGeneratedKeySize(const uint32_t size){ - - switch(size){ - case 16: - case 24: - case 32: - return true; - default: - return false; - } -} - -shared_ptr GetEncryptionState(ExpressionState &state) { - return GetSimpleEncryptionState(state)->encryption_state; -} - -// todo; template -LogicalType CreateEINTtypeStruct() { +LogicalType CreateEncryptionStruct() { return LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT}, {"nonce_lo", LogicalType::UBIGINT}, - {"value", LogicalType::INTEGER}}); -} - -LogicalType CreateEVARtypeStruct() { - return LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT}, - {"nonce_lo", LogicalType::UBIGINT}, - {"value", LogicalType::VARCHAR}}); + {"counter", LogicalType::UINTEGER}, + {"cipher", LogicalType::TINYINT}, + {"value", LogicalType::BLOB}}); } template -void EncryptToEtype(LogicalType result_struct, Vector &input_vector, - uint64_t size, ExpressionState &state, - Vector &result) { +void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, Vector &result) { + // local, global and encryption state auto &lstate = SimpleEncryptionFunctionLocalState::ResetAndGet(state); + auto simple_encryption_state = VCryptBasicFun::GetSimpleEncryptionState(state); + auto encryption_state = VCryptBasicFun::GetEncryptionState(state); + auto key = VCryptBasicFun::GetKey(state); - // this is the global state - auto simple_encryption_state = GetSimpleEncryptionState(state); - auto encryption_state = GetEncryptionState(state); - - // Get Key from Bind - auto key = GetKey(state); - - // Reset the reference of the result vector - Vector struct_vector(result_struct, size); + Vector struct_vector(CreateEncryptionStruct(), size); result.ReferenceAndSetType(struct_vector); - // ValidityMask &result_validity = FlatVector::Validity(result); - - if ((simple_encryption_state->counter == 0) || (HasSpace(simple_encryption_state, size) == false)) { - // generate new random IV and reset counter (if strart or if there is no space left) - SetIV(simple_encryption_state); - simple_encryption_state->counter = 0; - } - auto &children = StructVector::GetEntries(result); auto &nonce_hi = children[0]; - nonce_hi->SetVectorType(VectorType::CONSTANT_VECTOR); - - auto nonce_lo = simple_encryption_state->iv[1]; - - using ENCRYPTED_TYPE = StructTypeTernary; - using PLAINTEXT_TYPE = PrimitiveType; - - encryption_state->InitializeEncryption( - reinterpret_cast(simple_encryption_state->iv), 16, - reinterpret_cast(key)); - - GenericExecutor::ExecuteUnary( - input_vector, result, size, [&](PLAINTEXT_TYPE input) { - - simple_encryption_state->iv[1]++; - simple_encryption_state->counter++; - - encryption_state->InitializeEncryption( - reinterpret_cast(simple_encryption_state->iv), 16, - reinterpret_cast(key)); - - T encrypted_data = - ProcessVectorizedEncrypt(encryption_state, result, input.val, - lstate.buffer_p); + auto &nonce_lo = children[1]; + auto &counter_vec = children[2]; + auto &cipher_vec = children[3]; - return ENCRYPTED_TYPE{simple_encryption_state->iv[0], - simple_encryption_state->iv[1], encrypted_data}; - }); + // result vector containing encrypted data + auto &blob = children[4]; + // set the constant vectors + nonce_hi->SetVectorType(VectorType::CONSTANT_VECTOR); + nonce_lo->SetVectorType(VectorType::CONSTANT_VECTOR); + + auto nonce_hi_64 = simple_encryption_state->iv[0]; + auto nonce_lo_32 = simple_encryption_state->iv[0]; + + // is not the pointer but really the actual value copied? + // Set constant vectors to a single value + nonce_hi->Reference(Value::UBIGINT(nonce_hi_64)); + nonce_hi->Reference(Value::UBIGINT(nonce_lo_32)); + + counter_vec->SetVectorType(VectorType::FLAT_VECTOR); + cipher_vec->SetVectorType(VectorType::FLAT_VECTOR); + + // Set the blob vector to dict vector for compressed execution + blob->SetVectorType(VectorType::DICTIONARY_VECTOR); + + encryption_state->InitializeEncryption(reinterpret_cast(simple_encryption_state->iv), 16, key); + + auto &blob_child = DictionaryVector::Child(*blob); + auto &blob_sel = DictionaryVector::SelVector(*blob); + + // we process in batches of 128 values, or we can do it with all and cut at each 128 * sizeof(T) bits (only works for similar lengths) + // fill 512 bytes, so 512 / sizeof(T) values and at least 128 values. + // note: this only works for fixed-size types + auto batch_size = 128 * sizeof(T); + auto total_size = sizeof(T) * size; + // and the cipher + + // todo: assign buffer_p with the right size + encryption_state->Process(reinterpret_cast(input_vector), total_size, lstate.buffer_p, total_size); + + auto index = 0; + auto batch_nr = 0; + // get counter from local state + uint32_t counter = 0; + uint8_t cipher = 0; + const size_t step = sizeof(T) / 16; + uint64_t buffer_offset; + + // TODO: for strings this all works slighly different + for(int i = 0; i + 128; i < (DEFAULT_STANDARD_VECTOR_SIZE / 128)){ + + buffer_offset = batch_nr * sizeof(T) * 128; + // Allocate space in the dictionary vector (i.e. blob_child) + string_t batch_data = StringVector::EmptyString(blob_child, batch_size); // value size + *(uint32_t*) batch_data.GetPrefixWriteable() = *(uint32_t *) lstate.buffer_p + buffer_offset; + memcpy(batch_data.GetDataWriteable(), lstate.buffer_p, batch_size); + + // set index in selection vector + for (int j = 0; j++; j < 128){ + cipher = j % step; + counter += (cipher == 0 && index != 0) ? 1 : 0; + cipher_vec->SetValue(index, Value::TINYINT(cipher)); + counter_vec->SetValue(index, Value::UINTEGER(counter)); + blob_sel.set_index(index, batch_nr); + index++; + } + batch_nr++; + } } @@ -246,72 +127,77 @@ void DecryptFromEtype(Vector &input_vector, uint64_t size, // local state (contains key, buffer, iv etc.) auto &lstate = SimpleEncryptionFunctionLocalState::ResetAndGet(state); // global state - auto simple_encryption_state = GetSimpleEncryptionState(state); - auto encryption_state = GetEncryptionState(state); + auto simple_encryption_state = VCryptBasicFun::GetSimpleEncryptionState(state); + auto encryption_state = VCryptBasicFun::GetEncryptionState(state); // Get Key from Bind - auto key = GetKey(state); - - using ENCRYPTED_TYPE = StructTypeTernary; - using PLAINTEXT_TYPE = PrimitiveType; - - GenericExecutor::ExecuteUnary( - input_vector, result, size, [&](ENCRYPTED_TYPE input) { - simple_encryption_state->iv[0] = input.a_val; - simple_encryption_state->iv[1] = input.b_val; - - encryption_state->InitializeDecryption( - reinterpret_cast(simple_encryption_state->iv), 12, - reinterpret_cast(key)); - - T decrypted_data = - ProcessVectorizedDecrypt(encryption_state, result, input.c_val, - lstate.buffer_p); - return decrypted_data; - }); + auto key = VCryptBasicFun::GetKey(state); + +// using ENCRYPTED_TYPE = StructTypeTernary; +// using PLAINTEXT_TYPE = PrimitiveType; +// +// GenericExecutor::ExecuteUnary( +// input_vector, result, size, [&](ENCRYPTED_TYPE input) { +// simple_encryption_state->iv[0] = input.a_val; +// simple_encryption_state->iv[1] = input.b_val; +// +// encryption_state->InitializeDecryption( +// reinterpret_cast(simple_encryption_state->iv), 12, +// reinterpret_cast(key)); +// +// T decrypted_data = +// ProcessVectorizedDecrypt(encryption_state, result, input.c_val, +// lstate.buffer_p); +// return decrypted_data; +// }); } -static void EncryptDataToEtype(DataChunk &args, ExpressionState &state, +static void EncryptDataVectorized(DataChunk &args, ExpressionState &state, Vector &result) { auto &input_vector = args.data[0]; auto vector_type = input_vector.GetType(); auto size = args.size(); + UnifiedVectorFormat vdata_input; + input_vector.ToUnifiedFormat(args.size(), vdata_input); + ValidityMask &result_validity = FlatVector::Validity(result); + auto vd = vdata_input.data; + if (vector_type.IsNumeric()) { switch (vector_type.id()) { case LogicalTypeId::TINYINT: case LogicalTypeId::UTINYINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((int8_t *)vdata_input.data, size, state, result); case LogicalTypeId::SMALLINT: case LogicalTypeId::USMALLINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((int16_t *)vdata_input.data, size, state, result); case LogicalTypeId::INTEGER: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((int32_t *)vdata_input.data, size, state, result); case LogicalTypeId::UINTEGER: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((uint32_t *)vdata_input.data, size, state, result); case LogicalTypeId::BIGINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((int64_t *)vdata_input.data, size, state, result); case LogicalTypeId::UBIGINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((uint64_t *)vdata_input.data, size, state, result); case LogicalTypeId::FLOAT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((float *)vdata_input.data, size, state, result); case LogicalTypeId::DOUBLE: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + return EncryptVectorized((double *)vdata_input.data, size, state, result); default: throw NotImplementedException("Unsupported numeric type for encryption"); } } else if (vector_type.id() == LogicalTypeId::VARCHAR) { - return EncryptToEtype(CreateEVARtypeStruct(), input_vector, + return EncryptVectorized((string_t *)vdata_input.data, size, state, result); } else if (vector_type.IsNested()) { throw NotImplementedException( @@ -323,7 +209,7 @@ static void EncryptDataToEtype(DataChunk &args, ExpressionState &state, } -static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state, +static void DecryptDataVectorized(DataChunk &args, ExpressionState &state, Vector &result) { auto size = args.size(); @@ -372,7 +258,7 @@ static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state, } } -ScalarFunctionSet GetEncryptionStructFunction() { +ScalarFunctionSet GetEncryptionVectorizedFunction() { ScalarFunctionSet set("encrypt_vectorized"); for (auto &type : LogicalType::AllTypes()) { @@ -381,13 +267,13 @@ ScalarFunctionSet GetEncryptionStructFunction() { LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT}, {"nonce_lo", LogicalType::UBIGINT}, {"value", type}}), - EncryptDataToEtype, EncryptFunctionData::EncryptBind, nullptr, nullptr, SimpleEncryptionFunctionLocalState::Init)); + EncryptDataVectorized, EncryptFunctionData::EncryptBind, nullptr, nullptr, SimpleEncryptionFunctionLocalState::Init)); } return set; } -ScalarFunctionSet GetDecryptionStructFunction() { +ScalarFunctionSet GetDecryptionVectorizedFunction() { ScalarFunctionSet set("decrypt_vectorized"); for (auto &type : LogicalType::AllTypes()) { @@ -398,7 +284,7 @@ ScalarFunctionSet GetDecryptionStructFunction() { {"nonce_lo", nonce_type_b}, {"value", type}}), LogicalType::VARCHAR}, - type, DecryptDataFromEtype, EncryptFunctionData::EncryptBind, nullptr, nullptr, SimpleEncryptionFunctionLocalState::Init)); + type, DecryptDataVectorized, EncryptFunctionData::EncryptBind, nullptr, nullptr, SimpleEncryptionFunctionLocalState::Init)); } } @@ -416,10 +302,10 @@ ScalarFunctionSet GetDecryptionStructFunction() { // Register functions //------------------------------------------------------------------------------ -void CoreScalarFunctions::RegisterEncryptDataStructScalarFunction( +void CoreScalarFunctions::RegisterEncryptVectorizedScalarFunction( DatabaseInstance &db) { - ExtensionUtil::RegisterFunction(db, GetEncryptionStructFunction()); - ExtensionUtil::RegisterFunction(db, GetDecryptionStructFunction()); + ExtensionUtil::RegisterFunction(db, GetEncryptionVectorizedFunction()); + ExtensionUtil::RegisterFunction(db, GetDecryptionVectorizedFunction()); } } // namespace core } // namespace simple_encryption diff --git a/src/include/simple_encryption/core/functions/common.hpp b/src/include/simple_encryption/core/functions/common.hpp index c98f795..6e041ca 100644 --- a/src/include/simple_encryption/core/functions/common.hpp +++ b/src/include/simple_encryption/core/functions/common.hpp @@ -10,9 +10,8 @@ struct SimpleEncryptionFunctionLocalState : FunctionLocalState { public: ArenaAllocator arena; - - idx_t buffer_length; uint64_t iv[2]; + uint32_t counter = 0; // todo: key can be 16, 24 or 32 unsigned char key[16]; @@ -25,7 +24,9 @@ struct SimpleEncryptionFunctionLocalState : FunctionLocalState { explicit SimpleEncryptionFunctionLocalState(ClientContext &context, EncryptFunctionData *bind_data); static unique_ptr Init(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data); + static SimpleEncryptionFunctionLocalState &Get(ExpressionState &state); static SimpleEncryptionFunctionLocalState &ResetAndGet(ExpressionState &state); + static SimpleEncryptionFunctionLocalState &AllocateAndGet(ExpressionState &state, idx_t buffer_size); static SimpleEncryptionFunctionLocalState &ResetKeyAndGet(ExpressionState &state); }; diff --git a/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp b/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp index 82ba286..5d2462c 100644 --- a/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp +++ b/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp @@ -13,9 +13,10 @@ struct EncryptFunctionData : FunctionData { // Save the Key string key_name; string key; + LogicalType type; // BoundStatement relation; - EncryptFunctionData(ClientContext &context, string key_name) : context(context), key_name(key_name) { + EncryptFunctionData(ClientContext &context, string key_name, LogicalType type) : context(context), key_name(key_name), type(type) { // generate encryption key and store key = GetKeyFromSecret(context, key_name); } diff --git a/src/include/simple_encryption/core/functions/scalar.hpp b/src/include/simple_encryption/core/functions/scalar.hpp index 4c12e97..02453a6 100644 --- a/src/include/simple_encryption/core/functions/scalar.hpp +++ b/src/include/simple_encryption/core/functions/scalar.hpp @@ -10,11 +10,13 @@ struct CoreScalarFunctions { static void Register(duckdb::DatabaseInstance &db) { RegisterEncryptDataScalarFunction(db); RegisterEncryptDataStructScalarFunction(db); + RegisterEncryptVectorizedScalarFunction(db); } private: static void RegisterEncryptDataScalarFunction(duckdb::DatabaseInstance &db); static void RegisterEncryptDataStructScalarFunction(duckdb::DatabaseInstance &db); + static void RegisterEncryptVectorizedScalarFunction(duckdb::DatabaseInstance &db); }; } // namespace core diff --git a/src/include/simple_encryption/core/functions/scalar/encrypt.hpp b/src/include/simple_encryption/core/functions/scalar/encrypt.hpp index 92a6601..dc08f78 100644 --- a/src/include/simple_encryption/core/functions/scalar/encrypt.hpp +++ b/src/include/simple_encryption/core/functions/scalar/encrypt.hpp @@ -2,6 +2,7 @@ #include "simple_encryption/common.hpp" #include "duckdb/common/encryption_state.hpp" +#include "simple_encryption/core/functions/function_data/encrypt_function_data.hpp" #ifndef DUCKDB_AMALGAMATION #include "duckdb/storage/object_cache.hpp" @@ -11,19 +12,17 @@ namespace simple_encryption { namespace core { -class SimpleEncryptKeys : public ObjectCacheEntry { +class VCryptBasicFun { public: - static SimpleEncryptKeys &Get(ClientContext &context); + // Fix this later + static VCryptBasicFun &Get(ClientContext &context); public: - void AddKey(const string &key_name, const string &key); - bool HasKey(const string &key_name) const; - const string &GetKey(const string &key_name) const; - -public: - static string ObjectType(); - string GetObjectType() override; + static string* GetKey(ExpressionState &state); + static EncryptFunctionData &GetEncryptionBindInfo(ExpressionState &state); + static shared_ptr GetSimpleEncryptionState(ExpressionState &state); + static shared_ptr GetEncryptionState(ExpressionState &state); private: unordered_map keys; diff --git a/test/sql/compression/dict_compression.test b/test/sql/compression/dict_compression.test new file mode 100644 index 0000000..0c20545 --- /dev/null +++ b/test/sql/compression/dict_compression.test @@ -0,0 +1,28 @@ +# name: test/sql/bugfix_varchar_struct.test +# description: test simple_struct_encryption extension +# group: [simple_encryption] + +# Require statement will ensure this test is run with this extension loaded +require simple_encryption + +load __TEST_DIR__/test.db + +statement ok +create table tst(s struct(val blob)); + +statement ok +insert into tst select {'val':encode('01234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678990123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234'|| cast(range >> 7 as string))} s from range(150000); + +statement ok +select octet_length(s.val) len from tst; + +query I +select compression, count(*) from pragma_storage_info('tst') where column_path='[0, 1]' group by 1; +---- +1 +1 +1 +1 + +#statement ok +#from pragma_storage_info('tst') where row_group_id=0; \ No newline at end of file diff --git a/test/sql/vectorized/vectorized_encrypt.test b/test/sql/vectorized/vectorized_encrypt.test new file mode 100644 index 0000000..8cf26ce --- /dev/null +++ b/test/sql/vectorized/vectorized_encrypt.test @@ -0,0 +1,26 @@ +# name: test/sql/vectorized/vectorized_encrypt.test +# description: Test vectorized encrypt scalar function +# group: [simple-encryption/vectorized] + +require simple_encryption + +# Ensure any currently stored secrets don't interfere with the test +statement ok +set allow_persistent_secrets=false; + +# Create an internal secret (for internal encryption of columns) +statement ok +CREATE SECRET key_1 ( + TYPE ENCRYPTION, + TOKEN '0123456789112345', + LENGTH 16 +); + +statement ok +CREATE TABLE test_1 AS SELECT 1 AS value FROM range(10000); + +statement ok +ALTER TABLE test_1 ADD COLUMN encrypted_values STRUCT(nonce_hi UBIGINT, nonce_lo UBIGINT, counter UINTEGER, cipher UINTEGER, value BLOB) DEFAULT (STRUCT_PACK(nonce_hi := 0, nonce_lo := 0, counter := 0, cipher := 0, BLOB := 0)); + +statement ok +UPDATE test_1 SET encrypted_values = encrypt_vectorized(value, 'key_1'); \ No newline at end of file