Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
ccfelius committed Dec 18, 2024
1 parent 321f657 commit 3b658a1
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 44 deletions.
46 changes: 46 additions & 0 deletions experiments.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
-- Experiment 1: uint128 payload.
-- 'lo' carries range >> 15, 'ctr' the low 15 bits of range shifted left by one,
-- and 'val' places the low 3 bits of range at bit 124 and adds range >> 2.
CREATE OR REPLACE TABLE tst (
    s STRUCT(hi UINT64, lo UINT32, ctr UINT16, val UINT128)
);

INSERT INTO tst
SELECT
    {
        'hi': 0,
        'lo': range >> 15,
        'ctr': (range & 32767) << 1,
        'val': (range & 7) * (CAST(1 AS UINT128) << 124) + (range >> 2)
    } AS s
FROM range(100000000);

-- Storage layout of one row group.
SELECT *
FROM pragma_storage_info('tst')
WHERE row_group_id = 90;

-- Experiment 2: blob payload = fixed digit prefix + decimal text of range >> 3,
-- so every run of 8 consecutive range values shares one 'val'.
CREATE OR REPLACE TABLE tst (
    s STRUCT(hi UINT64, lo UINT32, ctr UINT16, val BLOB)
);

INSERT INTO tst
SELECT
    {
        'hi': 0,
        'lo': range >> 15,
        'ctr': (range & 32767) << 1,
        'val': encode('0123456789012345678901234' || CAST(range >> 3 AS STRING))
    } AS s
FROM range(100000000);

-- Eyeball a few rows.
SELECT *
FROM tst
LIMIT 10;

-- Distribution of payload byte lengths.
SELECT
    len,
    count(*)
FROM (
    SELECT octet_length(s.val) AS len
    FROM tst
) AS t
GROUP BY len;

-- Distribution of how often each distinct payload occurs.
SELECT
    cnt,
    count(*)
FROM (
    SELECT
        s.val,
        count(*) AS cnt
    FROM tst
    GROUP BY s.val
)
GROUP BY cnt;

-- Compression schemes chosen for column path [0, 4]
-- (presumably the 'val' child of the struct; confirm against pragma output).
SELECT
    compression,
    count(*)
FROM pragma_storage_info('tst')
WHERE column_path = '[0, 4]'
GROUP BY compression;

-- Storage layout of one row group.
SELECT *
FROM pragma_storage_info('tst')
WHERE row_group_id = 90;

-- Experiment 3: longer digit prefix + decimal text of range >> 4,
-- so every run of 16 consecutive range values shares one 'val'.
CREATE OR REPLACE TABLE tst (
    s STRUCT(hi UINT64, lo UINT32, ctr UINT16, val BLOB)
);

INSERT INTO tst
SELECT
    {
        'hi': 0,
        'lo': range >> 15,
        'ctr': (range & 32767) << 1,
        'val': encode('012345678901234567890123456789012345678901234567890123456' || CAST(range >> 4 AS STRING))
    } AS s
FROM range(100000000);

-- Eyeball a few rows.
SELECT *
FROM tst
LIMIT 10;

-- Distribution of payload byte lengths.
SELECT
    len,
    count(*)
FROM (
    SELECT octet_length(s.val) AS len
    FROM tst
) AS t
GROUP BY len;

-- Distribution of how often each distinct payload occurs.
SELECT
    cnt,
    count(*)
FROM (
    SELECT
        s.val,
        count(*) AS cnt
    FROM tst
    GROUP BY s.val
)
GROUP BY cnt;

-- Storage layout of one row group.
SELECT *
FROM pragma_storage_info('tst')
WHERE row_group_id = 90;

-- Compression schemes chosen for column path [0, 4]
-- (presumably the 'val' child of the struct; confirm against pragma output).
SELECT
    compression,
    count(*)
FROM pragma_storage_info('tst')
WHERE column_path = '[0, 4]'
GROUP BY compression;

-- Experiment 4: longer digit prefix + decimal text of range >> 5,
-- so every run of 32 consecutive range values shares one 'val'.
CREATE OR REPLACE TABLE tst (
    s STRUCT(hi UINT64, lo UINT32, ctr UINT16, val BLOB)
);

INSERT INTO tst
SELECT
    {
        'hi': 0,
        'lo': range >> 15,
        'ctr': (range & 32767) << 1,
        'val': encode('0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890' || CAST(range >> 5 AS STRING))
    } AS s
FROM range(100000000);

-- Distribution of payload byte lengths.
SELECT
    len,
    count(*)
FROM (
    SELECT octet_length(s.val) AS len
    FROM tst
) AS t
GROUP BY len;

-- Distribution of how often each distinct payload occurs.
SELECT
    cnt,
    count(*)
FROM (
    SELECT
        s.val,
        count(*) AS cnt
    FROM tst
    GROUP BY s.val
)
GROUP BY cnt;

-- Storage layout of one row group.
SELECT *
FROM pragma_storage_info('tst')
WHERE row_group_id = 90;

-- Compression schemes chosen for column path [0, 4]
-- (presumably the 'val' child of the struct; confirm against pragma output).
SELECT
    compression,
    count(*)
FROM pragma_storage_info('tst')
WHERE column_path = '[0, 4]'
GROUP BY compression;

-- Experiment 5: longer digit prefix + decimal text of range >> 6,
-- so every run of 64 consecutive range values shares one 'val'.
CREATE OR REPLACE TABLE tst (
    s STRUCT(hi UINT64, lo UINT32, ctr UINT16, val BLOB)
);

INSERT INTO tst
SELECT
    {
        'hi': 0,
        'lo': range >> 15,
        'ctr': (range & 32767) << 1,
        'val': encode('012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678' || CAST(range >> 6 AS STRING))
    } AS s
FROM range(100000000);

-- Distribution of payload byte lengths.
SELECT
    len,
    count(*)
FROM (
    SELECT octet_length(s.val) AS len
    FROM tst
) AS t
GROUP BY len;

-- Distribution of how often each distinct payload occurs.
SELECT
    cnt,
    count(*)
FROM (
    SELECT
        s.val,
        count(*) AS cnt
    FROM tst
    GROUP BY s.val
)
GROUP BY cnt;

-- Storage layout of one row group.
SELECT *
FROM pragma_storage_info('tst')
WHERE row_group_id = 90;

-- Compression schemes chosen for column path [0, 4]
-- (presumably the 'val' child of the struct; confirm against pragma output).
SELECT
    compression,
    count(*)
FROM pragma_storage_info('tst')
WHERE column_path = '[0, 4]'
GROUP BY compression;

-- Experiment 6: struct holding only the blob; payload = long digit prefix +
-- decimal text of range >> 7 (128 consecutive range values share one 'val').
CREATE OR REPLACE TABLE tst (
    s STRUCT(val BLOB)
);

INSERT INTO tst
SELECT
    {
        'val': encode('01234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678990123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234' || CAST(range >> 7 AS STRING))
    } AS s
FROM range(100000000);

-- Distribution of payload byte lengths.
SELECT
    len,
    count(*)
FROM (
    SELECT octet_length(s.val) AS len
    FROM tst
) AS t
GROUP BY len;

-- Storage layout of the first row group.
SELECT *
FROM pragma_storage_info('tst')
WHERE row_group_id = 0;

create or replace table tst(s struct(hi uint64, lo uint32, ctr uint16, val blob));
insert into tst select {'hi': 0,'lo': range>>11,'ctr': (range&2047)<<1, 'val':encode('0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567899012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456'|| cast(range >> 8 as string))} s from range(100000000);
select len,count(*) from (select octet_length(s.val) len from tst) t group by len;
select cnt,count(*) from (select s.val,count(*) cnt from tst group by 1) group by 1;
from pragma_storage_info('tst') where row_group_id=90;
select compression,count(*) from pragma_storage_info('tst') where column_path='[0, 4]' group by 1;

103 changes: 59 additions & 44 deletions src/core/functions/scalar/encrypt_vectorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,12 @@ shared_ptr<EncryptionState> GetEncryptionState(ExpressionState &state) {
}

// todo; template
LogicalType CreateEINTtypeStruct() {
LogicalType CreateBLOBStruct() {
return LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT},
{"nonce_lo", LogicalType::UBIGINT},
{"value", LogicalType::INTEGER}});
{"counter", LogicalType::UINTEGER},
{"cipher", LogicalType::TINYINT},
{"value", LogicalType::BLOB}});
}

LogicalType CreateEVARtypeStruct() {
Expand All @@ -180,61 +182,60 @@ LogicalType CreateEVARtypeStruct() {
}

template <typename T>
void EncryptToEtype(LogicalType result_struct, Vector &input_vector,
uint32_t CalculateBlockCounter(uint32_t counter) {
return ceil(counter * sizeof(T) / 16);
}

template <typename T>
void EncryptVectorized(LogicalType result_struct, Vector &input_vector,
uint64_t size, ExpressionState &state,
Vector &result) {

// local, global and encryption state
auto &lstate = SimpleEncryptionFunctionLocalState::ResetAndGet(state);

// this is the global state
auto simple_encryption_state = GetSimpleEncryptionState(state);
auto encryption_state = GetEncryptionState(state);

// Get Key from Bind
auto key = GetKey(state);

// Reset the reference of the result vector
Vector struct_vector(result_struct, size);
Vector struct_vector(CreateBLOBStruct(), size);
result.ReferenceAndSetType(struct_vector);

// ValidityMask &result_validity = FlatVector::Validity(result);

if ((simple_encryption_state->counter == 0) || (HasSpace(simple_encryption_state, size) == false)) {
// generate a new random IV and reset the counter (at start, or if there is no space left)
SetIV(simple_encryption_state);
simple_encryption_state->counter = 0;
}

auto &children = StructVector::GetEntries(result);
auto &nonce_hi = children[0];
auto &nonce_lo = children[1];
auto &counter = children[2];
auto &cipher = children[3];

// result vector containing encrypted data
auto &blob = children[4];

nonce_hi->SetVectorType(VectorType::CONSTANT_VECTOR);
nonce_lo->SetVectorType(VectorType::CONSTANT_VECTOR);

auto nonce_lo = simple_encryption_state->iv[1];
// put counter also in the local state
counter->SetVectorType(VectorType::FLAT_VECTOR);
cipher->SetVectorType(VectorType::FLAT_VECTOR);

using ENCRYPTED_TYPE = StructTypeTernary<uint64_t, uint64_t, T>;
using PLAINTEXT_TYPE = PrimitiveType<T>;
// auto nonce_lo = simple_encryption_state->iv[1];

encryption_state->InitializeEncryption(
reinterpret_cast<const_data_ptr_t>(simple_encryption_state->iv), 16,
reinterpret_cast<const string *>(key));
// using ENCRYPTED_TYPE = StructTypeTernary<uint64_t, uint64_t, T>;
// using PLAINTEXT_TYPE = PrimitiveType<T>;
// but how do we increase the counter?

GenericExecutor::ExecuteUnary<PLAINTEXT_TYPE, ENCRYPTED_TYPE>(
input_vector, result, size, [&](PLAINTEXT_TYPE input) {

simple_encryption_state->iv[1]++;
simple_encryption_state->counter++;

encryption_state->InitializeEncryption(
reinterpret_cast<const_data_ptr_t>(simple_encryption_state->iv), 16,
reinterpret_cast<const string *>(key));
encryption_state->InitializeEncryption(
reinterpret_cast<const_data_ptr_t>(simple_encryption_state->iv), 16,
key);

T encrypted_data =
ProcessVectorizedEncrypt(encryption_state, result, input.val,
lstate.buffer_p);
T encrypted_data =
ProcessVectorizedEncrypt(encryption_state, result, input.val,
lstate.buffer_p);

return ENCRYPTED_TYPE{simple_encryption_state->iv[0],
simple_encryption_state->iv[1], encrypted_data};
});
});

}

Expand Down Expand Up @@ -272,46 +273,60 @@ void DecryptFromEtype(Vector &input_vector, uint64_t size,
}


static void EncryptDataToEtype(DataChunk &args, ExpressionState &state,
static void EncryptDataVectorized(DataChunk &args, ExpressionState &state,
Vector &result) {

auto &input_vector = args.data[0];
auto vector_type = input_vector.GetType();
auto size = args.size();

// get src and dst vectors for searches
auto &src = args.data[2];
auto &dst = args.data[3];

UnifiedVectorFormat vdata_src;
UnifiedVectorFormat vdata_dst;

src.ToUnifiedFormat(args.size(), vdata_src);
dst.ToUnifiedFormat(args.size(), vdata_dst);
auto src_data = (int64_t *)vdata_src.data;
auto dst_data = (int64_t *)vdata_dst.data;

ValidityMask &result_validity = FlatVector::Validity(result);

if (vector_type.IsNumeric()) {
switch (vector_type.id()) {
case LogicalTypeId::TINYINT:
case LogicalTypeId::UTINYINT:
return EncryptToEtype<int8_t>(CreateEINTtypeStruct(), input_vector,
return EncryptVectorized<int8_t>(CreateEINTtypeStruct(), input_vector,
size, state, result);
case LogicalTypeId::SMALLINT:
case LogicalTypeId::USMALLINT:
return EncryptToEtype<int16_t>(CreateEINTtypeStruct(), input_vector,
return EncryptVectorized<int16_t>(CreateEINTtypeStruct(), input_vector,
size, state, result);
case LogicalTypeId::INTEGER:
return EncryptToEtype<int32_t>(CreateEINTtypeStruct(), input_vector,
return EncryptVectorized<int32_t>(CreateEINTtypeStruct(), input_vector,
size, state, result);
case LogicalTypeId::UINTEGER:
return EncryptToEtype<uint32_t>(CreateEINTtypeStruct(), input_vector,
return EncryptVectorized<uint32_t>(CreateEINTtypeStruct(), input_vector,
size, state, result);
case LogicalTypeId::BIGINT:
return EncryptToEtype<int64_t>(CreateEINTtypeStruct(), input_vector,
return EncryptVectorized<int64_t>(CreateEINTtypeStruct(), input_vector,
size, state, result);
case LogicalTypeId::UBIGINT:
return EncryptToEtype<uint64_t>(CreateEINTtypeStruct(), input_vector,
return EncryptVectorized<uint64_t>(CreateEINTtypeStruct(), input_vector,
size, state, result);
case LogicalTypeId::FLOAT:
return EncryptToEtype<float>(CreateEINTtypeStruct(), input_vector,
return EncryptVectorized<float>(CreateEINTtypeStruct(), input_vector,
size, state, result);
case LogicalTypeId::DOUBLE:
return EncryptToEtype<double>(CreateEINTtypeStruct(), input_vector,
return EncryptVectorized<double>(CreateEINTtypeStruct(), input_vector,
size, state, result);
default:
throw NotImplementedException("Unsupported numeric type for encryption");
}
} else if (vector_type.id() == LogicalTypeId::VARCHAR) {
return EncryptToEtype<string_t>(CreateEVARtypeStruct(), input_vector,
return EncryptVectorized<string_t>(CreateEVARtypeStruct(), input_vector,
size, state, result);
} else if (vector_type.IsNested()) {
throw NotImplementedException(
Expand Down Expand Up @@ -381,7 +396,7 @@ ScalarFunctionSet GetEncryptionStructFunction() {
LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT},
{"nonce_lo", LogicalType::UBIGINT},
{"value", type}}),
EncryptDataToEtype, EncryptFunctionData::EncryptBind, nullptr, nullptr, SimpleEncryptionFunctionLocalState::Init));
EncryptDataVectorized, EncryptFunctionData::EncryptBind, nullptr, nullptr, SimpleEncryptionFunctionLocalState::Init));
}

return set;
Expand Down

0 comments on commit 3b658a1

Please sign in to comment.