Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and add vectorized encryption first version #29

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
src/core/module.cpp
src/core/types.cpp
src/core/functions/scalar/encrypt.cpp
src/core/functions/scalar/encrypt_to_etype.cpp
src/core/functions/scalar/encrypt_naive.cpp
src/core/functions/scalar/encrypt_vectorized.cpp
src/core/functions/function_data/encrypt_function_data.cpp
src/core/functions/cast/varchar_cast.cpp
Expand Down
46 changes: 46 additions & 0 deletions experiments.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val uint128));
insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1, 'val': (range & 7) * (cast(1 as uint128) << 124) + (range >>2)} s from range(100000000);
from pragma_storage_info('tst') where row_group_id=90;

create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob));
insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('0123456789012345678901234'|| cast(range >> 3 as string))} s from range(100000000);
from tst limit 10;
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1;
select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1;
from pragma_storage_info('tst') where row_group_id=90;

create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob));
insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('012345678901234567890123456789012345678901234567890123456'|| cast(range >> 4 as string))} s from range(100000000);
from tst limit 10;
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1;
from pragma_storage_info('tst') where row_group_id=90;
select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1;

create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob));
insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890'|| cast(range >> 5 as string))} s from range(100000000);
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1;
from pragma_storage_info('tst') where row_group_id=90;
select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1;

create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob));
insert into tst select {'hi': 0,'lo': range>>15,'ctr': (range&32767)<<1,'val':encode('012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678'|| cast(range >> 6 as string))} s from range(100000000);
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1;
from pragma_storage_info('tst') where row_group_id=90;
select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1;

create or replace table tst(s struct(val blob));
insert into tst select {'val':encode('01234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678990123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234'|| cast(range >> 7 as string))} s from range(100000000);
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
from pragma_storage_info('tst') where row_group_id=0;

create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob));
insert into tst select {'hi': 0,'lo': range>>11,'ctr': (range&2047)<<1, 'val':encode('0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567899012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456'|| cast(range >> 8 as string))} s from range(100000000);
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1;
from pragma_storage_info('tst') where row_group_id=90;
select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1;

33 changes: 29 additions & 4 deletions src/core/functions/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,50 @@ SimpleEncryptionFunctionLocalState::SimpleEncryptionFunctionLocalState(ClientCon
iv[0] = iv[1] = 0;

// maybe generate iv_high also already in the bind
// allocate depending in sizeof(T) * items_in_vector
// maybe already in registering the function!
size_t data_size;
LogicalType type = bind_data->type;

// for now do 512 bytes
buffer_length = 512;
encryption_buffer = arena.Allocate(buffer_length);
// todo; fix this for all other types
if (type == LogicalType::VARCHAR) {
// allocate buffer for encrypted data
data_size = 512;
} else {
// maybe we can also just do per vector for certain types, so more then 128
data_size = GetTypeIdSize(type.InternalType()) * 128;
}

buffer_p = (data_ptr_t)encryption_buffer;
buffer_p = (data_ptr_t)arena.Allocate(data_size);

if (bind_data->type.id() == LogicalTypeId::VARCHAR) {
// allocate buffer for encrypted data
buffer_p = (data_ptr_t)arena.Allocate(128);
}
}

unique_ptr<FunctionLocalState>
SimpleEncryptionFunctionLocalState::Init(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) {
return make_uniq<SimpleEncryptionFunctionLocalState>(state.GetContext(), static_cast<EncryptFunctionData *>(bind_data));
}

SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::Get(ExpressionState &state) {
auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast<SimpleEncryptionFunctionLocalState>();
return local_state;
}

SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::ResetAndGet(ExpressionState &state) {
auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast<SimpleEncryptionFunctionLocalState>();
local_state.arena.Reset();
return local_state;
}

SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::AllocateAndGet(ExpressionState &state, idx_t buffer_size) {
auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast<SimpleEncryptionFunctionLocalState>();
local_state.arena.Allocate(buffer_size);
return local_state;
}

SimpleEncryptionFunctionLocalState &SimpleEncryptionFunctionLocalState::ResetKeyAndGet(ExpressionState &state) {
auto &local_state = ExecuteFunctionState::GetFunctionState(state)->Cast<SimpleEncryptionFunctionLocalState>();
local_state.arena.Reset();
Expand Down
10 changes: 8 additions & 2 deletions src/core/functions/function_data/encrypt_function_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct KeyData {
};

unique_ptr<FunctionData> EncryptFunctionData::Copy() const {
return make_uniq<EncryptFunctionData>(context, key_name);
return make_uniq<EncryptFunctionData>(context, key_name, type);
}

bool EncryptFunctionData::Equals(const FunctionData &other_p) const {
Expand Down Expand Up @@ -64,6 +64,12 @@ EncryptFunctionData::EncryptBind(ClientContext &context,
ScalarFunction &bound_function,
vector<unique_ptr<Expression>> &arguments) {

auto &value = arguments[0];

if (arguments.size() != 2) {
throw BinderException("Encrypt Scalar Function requires two arguments");
}

auto &key_child = arguments[1];
if (key_child->HasParameter()) {
throw ParameterNotResolvedException();
Expand All @@ -81,7 +87,7 @@ EncryptFunctionData::EncryptBind(ClientContext &context,

auto key_name = StringUtil::Lower(key_str);

return make_uniq<EncryptFunctionData>(context, key_name);
return make_uniq<EncryptFunctionData>(context, key_name, value->return_type);
}
} // namespace core
} // namespace simple_encryption
2 changes: 1 addition & 1 deletion src/core/functions/scalar/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
set(EXTENSION_SOURCES
${EXTENSION_SOURCES}
${CMAKE_CURRENT_SOURCE_DIR}/encrypt.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encrypt_to_etype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encrypt_naive.cpp
PARENT_SCOPE
)
30 changes: 26 additions & 4 deletions src/core/functions/scalar/encrypt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,38 @@
#include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>
#include "duckdb/common/types/blob.hpp"
#include "duckdb/main/connection_manager.hpp"
#include "simple_encryption/core/functions/scalar/encrypt.hpp"
#include "simple_encryption/core/functions/scalar.hpp"
#include "simple_encryption_state.hpp"
#include "duckdb/common/encryption_state.hpp"
#include "duckdb/main/client_context.hpp"
#include "simple_encryption/core/functions/function_data/encrypt_function_data.hpp"

#include "simple_encryption_state.hpp"
#include "simple_encryption/core/functions/scalar.hpp"
#include "simple_encryption/core/functions/scalar/encrypt.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"

namespace simple_encryption {
namespace core {

EncryptFunctionData& VCryptBasicFun::GetEncryptionBindInfo(ExpressionState &state) {
auto &func_expr = (BoundFunctionExpression &)state.expr;
return (EncryptFunctionData &)*func_expr.bind_info;
}

shared_ptr<SimpleEncryptionState>
VCryptBasicFun::GetSimpleEncryptionState(ExpressionState &state) {
auto &info = VCryptBasicFun::GetEncryptionBindInfo(state);
return info.context.registered_state->Get<SimpleEncryptionState>(
"simple_encryption");
}
// TODO; maybe pass by reference or so
string* VCryptBasicFun::GetKey(ExpressionState &state) {
auto &info = VCryptBasicFun::GetEncryptionBindInfo(state);
return &info.key;
}

shared_ptr<EncryptionState> VCryptBasicFun::GetEncryptionState(ExpressionState &state) {
return VCryptBasicFun::GetSimpleEncryptionState(state)->encryption_state;
}

shared_ptr<EncryptionUtil> GetEncryptionUtil(ExpressionState &state) {
auto &func_expr = (BoundFunctionExpression &)state.expr;
auto &info = (EncryptFunctionData &)*func_expr.bind_info;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,6 @@ GetSimpleEncryptionState(ExpressionState &state) {
"simple_encryption");
}

// TODO; maybe pass by reference or so
string* GetKey(ExpressionState &state) {
auto &info = GetEncryptionBindInfo(state);
return &info.key;
}

shared_ptr<EncryptionState> GetSimpleEncryptionStateLocal(ExpressionState &state) {
auto &info = GetEncryptionBindInfo(state);
// create a new local encryption state, but get the nonce etc. from the global state.
auto encryption_util = GetSimpleEncryptionState(state)->encryption_util;

return info.context.registered_state->Get<SimpleEncryptionState>(
"simple_encryption")->encryption_util->CreateEncryptionState();
}

bool HasSpace(shared_ptr<SimpleEncryptionState> simple_encryption_state,
uint64_t size) {
uint32_t max_value = ~0u;
Expand Down Expand Up @@ -210,10 +195,6 @@ bool CheckGeneratedKeySize(const uint32_t size){
}
}

shared_ptr<EncryptionState> GetEncryptionState(ExpressionState &state) {
return GetSimpleEncryptionState(state)->encryption_state;
}

// todo; template
LogicalType CreateEINTtypeStruct() {
return LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT},
Expand All @@ -236,10 +217,10 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector,

// this is the global state
auto simple_encryption_state = GetSimpleEncryptionState(state);
auto encryption_state = GetEncryptionState(state);
auto encryption_state = VCryptBasicFun::GetEncryptionState(state);

// Get Key from Bind
auto key = GetKey(state);
auto key = VCryptBasicFun::GetKey(state);

// Reset the reference of the result vector
Vector struct_vector(result_struct, size);
Expand Down Expand Up @@ -295,10 +276,10 @@ void DecryptFromEtype(Vector &input_vector, uint64_t size,
auto &lstate = SimpleEncryptionFunctionLocalState::ResetAndGet(state);
// global state
auto simple_encryption_state = GetSimpleEncryptionState(state);
auto encryption_state = GetEncryptionState(state);
auto encryption_state = VCryptBasicFun::GetEncryptionState(state);

// Get Key from Bind
auto key = GetKey(state);
auto key = VCryptBasicFun::GetKey(state);

using ENCRYPTED_TYPE = StructTypeTernary<uint64_t, uint64_t, T>;
using PLAINTEXT_TYPE = PrimitiveType<T>;
Expand Down
Loading
Loading