Skip to content

Commit

Permalink
Merge pull request #29 from ccfelius/refactor
Browse files Browse the repository at this point in the history
Refactor and add vectorized encryption first version
  • Loading branch information
ccfelius authored Dec 19, 2024
2 parents 015e976 + 38cc872 commit 8bc5749
Show file tree
Hide file tree
Showing 14 changed files with 304 additions and 281 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ set(EXTENSION_SOURCES
src/core/module.cpp
src/core/types.cpp
src/core/functions/scalar/encrypt.cpp
src/core/functions/scalar/encrypt_to_etype.cpp
src/core/functions/scalar/encrypt_naive.cpp
src/core/functions/scalar/encrypt_vectorized.cpp
src/core/functions/function_data/encrypt_function_data.cpp
src/core/functions/cast/varchar_cast.cpp
Expand Down
46 changes: 46 additions & 0 deletions experiments.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val uint128));
insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1, 'val': (range & 7) * (cast(1 as uint128) << 124) + (range >>2)} s from range(100000000);
from pragma_storage_info('tst') where row_group_id=90;

create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob));
insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('0123456789012345678901234'|| cast(range >> 3 as string))} s from range(100000000);
from tst limit 10;
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1;
select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1;
from pragma_storage_info('tst') where row_group_id=90;

create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob));
insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('012345678901234567890123456789012345678901234567890123456'|| cast(range >> 4 as string))} s from range(100000000);
from tst limit 10;
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1;
from pragma_storage_info('tst') where row_group_id=90;
select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1;

create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob));
insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890'|| cast(range >> 5 as string))} s from range(100000000);
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1;
from pragma_storage_info('tst') where row_group_id=90;
select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1;

create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob));
insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678'|| cast(range >> 6 as string))} s from range(100000000);
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1;
from pragma_storage_info('tst') where row_group_id=90;
select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1;

create or replace table tst(s struct(val blob));
insert into tst select {'val':encode('01234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678990123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234'|| cast(range >> 7 as string))} s from range(100000000);
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
from pragma_storage_info('tst') where row_group_id=0;

create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob));
insert into tst select {'hi': 0,'lo': range>>11,'ctr': (range&2047)<<1, 'val':encode('0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567899012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456'|| cast(range >> 8 as string))} s from range(100000000);
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1;
from pragma_storage_info('tst') where row_group_id=90;
select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1;

33 changes: 29 additions & 4 deletions src/core/functions/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,50 @@ SimpleEncryptionFunctionLocalState::SimpleEncryptionFunctionLocalState(ClientCon
iv[0] = iv[1] = 0;

// maybe generate iv_high also already in the bind
// allocate depending in sizeof(T) * items_in_vector
// maybe already in registering the function!
size_t data_size;
LogicalType type = bind_data->type;

// for now do 512 bytes
buffer_length = 512;
encryption_buffer = arena.Allocate(buffer_length);
// todo; fix this for all other types
if (type == LogicalType::VARCHAR) {
// allocate buffer for encrypted data
data_size = 512;
} else {
// maybe we can also just do per vector for certain types, so more then 128
data_size = GetTypeIdSize(type.InternalType()) * 128;
}

buffer_p = (data_ptr_t)encryption_buffer;
buffer_p = (data_ptr_t)arena.Allocate(data_size);

if (bind_data->type.id() == LogicalTypeId::VARCHAR) {
// allocate buffer for encrypted data
buffer_p = (data_ptr_t)arena.Allocate(128);
}
}

unique_ptr<FunctionLocalState>
SimpleEncryptionFunctionLocalState::Init(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) {
return make_uniq<SimpleEncryptionFunctionLocalState>(state.GetContext(), static_cast<EncryptFunctionData *>(bind_data));
}

SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::Get(ExpressionState &state) {
auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast<SimpleEncryptionFunctionLocalState>();
return local_state;
}

SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::ResetAndGet(ExpressionState &state) {
auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast<SimpleEncryptionFunctionLocalState>();
local_state.arena.Reset();
return local_state;
}

SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::AllocateAndGet(ExpressionState &state, idx_t buffer_size) {
auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast<SimpleEncryptionFunctionLocalState>();
local_state.arena.Allocate(buffer_size);
return local_state;
}

SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::ResetKeyAndGet(ExpressionState &state) {
auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast<SimpleEncryptionFunctionLocalState>();
local_state.arena.Reset();
Expand Down
10 changes: 8 additions & 2 deletions src/core/functions/function_data/encrypt_function_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct KeyData {
};

unique_ptr<FunctionData> EncryptFunctionData::Copy() const {
return make_uniq<EncryptFunctionData>(context, key_name);
return make_uniq<EncryptFunctionData>(context, key_name, type);
}

bool EncryptFunctionData::Equals(const FunctionData &other_p) const {
Expand Down Expand Up @@ -64,6 +64,12 @@ EncryptFunctionData::EncryptBind(ClientContext &context,
ScalarFunction &bound_function,
vector<unique_ptr<Expression>> &arguments) {

auto &value = arguments[0];

if (arguments.size() != 2) {
throw BinderException("Encrypt Scalar Function requires two arguments");
}

auto &key_child = arguments[1];
if (key_child->HasParameter()) {
throw ParameterNotResolvedException();
Expand All @@ -81,7 +87,7 @@ EncryptFunctionData::EncryptBind(ClientContext &context,

auto key_name = StringUtil::Lower(key_str);

return make_uniq<EncryptFunctionData>(context, key_name);
return make_uniq<EncryptFunctionData>(context, key_name, value->return_type);
}
} // namespace core
} // namespace simple_encryption
2 changes: 1 addition & 1 deletion src/core/functions/scalar/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
set(EXTENSION_SOURCES
${EXTENSION_SOURCES}
${CMAKE_CURRENT_SOURCE_DIR}/encrypt.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encrypt_to_etype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encrypt_naive.cpp
PARENT_SCOPE
)
30 changes: 26 additions & 4 deletions src/core/functions/scalar/encrypt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,38 @@
#include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>
#include "duckdb/common/types/blob.hpp"
#include "duckdb/main/connection_manager.hpp"
#include "simple_encryption/core/functions/scalar/encrypt.hpp"
#include "simple_encryption/core/functions/scalar.hpp"
#include "simple_encryption_state.hpp"
#include "duckdb/common/encryption_state.hpp"
#include "duckdb/main/client_context.hpp"
#include "simple_encryption/core/functions/function_data/encrypt_function_data.hpp"

#include "simple_encryption_state.hpp"
#include "simple_encryption/core/functions/scalar.hpp"
#include "simple_encryption/core/functions/scalar/encrypt.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"

namespace simple_encryption {
namespace core {

EncryptFunctionData& VCryptBasicFun::GetEncryptionBindInfo(ExpressionState &state) {
auto &func_expr = (BoundFunctionExpression &)state.expr;
return (EncryptFunctionData &)*func_expr.bind_info;
}

shared_ptr<SimpleEncryptionState>
VCryptBasicFun::GetSimpleEncryptionState(ExpressionState &state) {
auto &info = VCryptBasicFun::GetEncryptionBindInfo(state);
return info.context.registered_state->Get<SimpleEncryptionState>(
"simple_encryption");
}
// TODO; maybe pass by reference or so
string* VCryptBasicFun::GetKey(ExpressionState &state) {
auto &info = VCryptBasicFun::GetEncryptionBindInfo(state);
return &info.key;
}

shared_ptr<EncryptionState> VCryptBasicFun::GetEncryptionState(ExpressionState &state) {
return VCryptBasicFun::GetSimpleEncryptionState(state)->encryption_state;
}

shared_ptr<EncryptionUtil> GetEncryptionUtil(ExpressionState &state) {
auto &func_expr = (BoundFunctionExpression &)state.expr;
auto &info = (EncryptFunctionData &)*func_expr.bind_info;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,6 @@ GetSimpleEncryptionState(ExpressionState &state) {
"simple_encryption");
}

// TODO; maybe pass by reference or so
string* GetKey(ExpressionState &state) {
auto &info = GetEncryptionBindInfo(state);
return &info.key;
}

shared_ptr<EncryptionState> GetSimpleEncryptionStateLocal(ExpressionState &state) {
auto &info = GetEncryptionBindInfo(state);
// create a new local encryption state, but get the nonce etc. from the global state.
auto encryption_util = GetSimpleEncryptionState(state)->encryption_util;

return info.context.registered_state->Get<SimpleEncryptionState>(
"simple_encryption")->encryption_util->CreateEncryptionState();
}

bool HasSpace(shared_ptr<SimpleEncryptionState> simple_encryption_state,
uint64_t size) {
uint32_t max_value = ~0u;
Expand Down Expand Up @@ -210,10 +195,6 @@ bool CheckGeneratedKeySize(const uint32_t size){
}
}

shared_ptr<EncryptionState> GetEncryptionState(ExpressionState &state) {
return GetSimpleEncryptionState(state)->encryption_state;
}

// todo; template
LogicalType CreateEINTtypeStruct() {
return LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT},
Expand All @@ -236,10 +217,10 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector,

// this is the global state
auto simple_encryption_state = GetSimpleEncryptionState(state);
auto encryption_state = GetEncryptionState(state);
auto encryption_state = VCryptBasicFun::GetEncryptionState(state);

// Get Key from Bind
auto key = GetKey(state);
auto key = VCryptBasicFun::GetKey(state);

// Reset the reference of the result vector
Vector struct_vector(result_struct, size);
Expand Down Expand Up @@ -295,10 +276,10 @@ void DecryptFromEtype(Vector &input_vector, uint64_t size,
auto &lstate = SimpleEncryptionFunctionLocalState::ResetAndGet(state);
// global state
auto simple_encryption_state = GetSimpleEncryptionState(state);
auto encryption_state = GetEncryptionState(state);
auto encryption_state = VCryptBasicFun::GetEncryptionState(state);

// Get Key from Bind
auto key = GetKey(state);
auto key = VCryptBasicFun::GetKey(state);

using ENCRYPTED_TYPE = StructTypeTernary<uint64_t, uint64_t, T>;
using PLAINTEXT_TYPE = PrimitiveType<T>;
Expand Down
Loading

0 comments on commit 8bc5749

Please sign in to comment.