Skip to content

Commit

Permalink
vectorized encrypt wip
Browse files Browse the repository at this point in the history
  • Loading branch information
ccfelius committed Jan 7, 2025
1 parent 5c7b707 commit c11defa
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 42 deletions.
4 changes: 2 additions & 2 deletions src/core/functions/scalar/encrypt_naive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ ProcessAndCastEncrypt(shared_ptr<EncryptionState> encryption_state,

// convert to Base64 into a newly allocated string in the result vector
T base64_data = StringVector::EmptyString(*result_vector, base64_size);
memset(base64_data.GetDataWriteable(), 0, 12);
base64_data.Finalize();
Blob::ToBase64(encrypted_data, base64_data.GetDataWriteable());

return base64_data;
Expand All @@ -89,7 +89,7 @@ ProcessEncrypt(shared_ptr<EncryptionState> encryption_state,

// convert to Base64 into a newly allocated string in the result vector
T base64_data = StringVector::EmptyString(*result_vector, base64_size);
memset(base64_data.GetDataWriteable(), 0, 12);
base64_data.Finalize();
Blob::ToBase64(encrypted_data, base64_data.GetDataWriteable());

return base64_data;
Expand Down
127 changes: 87 additions & 40 deletions src/core/functions/scalar/encrypt_vectorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "simple_encryption/core/functions/scalar/encrypt.hpp"
#include "simple_encryption/core/functions/function_data/encrypt_function_data.hpp"

#define BATCH_SIZE 128

namespace simple_encryption {

namespace core {
Expand All @@ -40,7 +42,9 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V

// local, global and encryption state
auto &lstate = SimpleEncryptionFunctionLocalState::ResetAndGet(state);
auto simple_encryption_state = VCryptBasicFun::GetSimpleEncryptionState(state);
auto vcrypt_state =
VCryptBasicFun::GetSimpleEncryptionState(state);

auto encryption_state = VCryptBasicFun::GetEncryptionState(state);
auto key = VCryptBasicFun::GetKey(state);

Expand All @@ -53,73 +57,116 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V
auto &counter_vec = children[2];
auto &cipher_vec = children[3];

// result vector containing encrypted data
auto &blob = children[4];
// counter_vec->SetVectorType(VectorType::FLAT_VECTOR);
// cipher_vec->SetVectorType(VectorType::FLAT_VECTOR);

// set the constant vectors
nonce_hi->SetVectorType(VectorType::CONSTANT_VECTOR);
nonce_lo->SetVectorType(VectorType::CONSTANT_VECTOR);
UnifiedVectorFormat nonce_hi_u;
UnifiedVectorFormat nonce_lo_u;
UnifiedVectorFormat counter_vec_u;
UnifiedVectorFormat cipher_vec_u;

auto nonce_hi_64 = simple_encryption_state->iv[0];
auto nonce_lo_32 = simple_encryption_state->iv[0];
nonce_hi->ToUnifiedFormat(size, nonce_hi_u);
nonce_lo->ToUnifiedFormat(size, nonce_lo_u);
counter_vec->ToUnifiedFormat(size, counter_vec_u);
cipher_vec->ToUnifiedFormat(size, cipher_vec_u);

// is not the pointer but really the actual value copied?
// Set constant vectors to a single value
nonce_hi->Reference(Value::UBIGINT(nonce_hi_64));
nonce_hi->Reference(Value::UBIGINT(nonce_lo_32));
auto nonce_hi_data = FlatVector::GetData<uint64_t>(*nonce_hi);
auto nonce_lo_data = FlatVector::GetData<uint32_t>(*nonce_lo);
auto counter_vec_data = FlatVector::GetData<uint32_t>(*counter_vec);
auto cipher_vec_data = FlatVector::GetData<uint8_t>(*cipher_vec);

counter_vec->SetVectorType(VectorType::FLAT_VECTOR);
cipher_vec->SetVectorType(VectorType::FLAT_VECTOR);
// set the nonces directly
nonce_hi_data[0] = vcrypt_state->iv[0];
nonce_lo_data[0] = vcrypt_state->iv[1];

// Set the blob vector to dict vector for compressed execution
blob->SetVectorType(VectorType::DICTIONARY_VECTOR);
nonce_hi->SetVectorType(VectorType::CONSTANT_VECTOR);
nonce_lo->SetVectorType(VectorType::CONSTANT_VECTOR);

encryption_state->InitializeEncryption(reinterpret_cast<const_data_ptr_t>(simple_encryption_state->iv), 16, key);
// result vector containing encrypted data
auto &blob = children[4];
SelectionVector sel(size);
blob->Slice(*blob, sel, size);

auto &blob_child = DictionaryVector::Child(*blob);
auto &blob_sel = DictionaryVector::SelVector(*blob);
blob_sel.Initialize(size);
auto &blob_child = DictionaryVector::Child(*blob);
auto blob_child_data = FlatVector::GetData<string_t>(blob_child);

encryption_state->InitializeEncryption(
reinterpret_cast<const_data_ptr_t>(vcrypt_state->iv), 16, key);

// we process in batches of 128 values, or we can do it with all and cut at each 128 * sizeof(T) bytes (only works for similar lengths)
// we process in batches of 128 values
// or we can do it with all and cut at each 128 * sizeof(T) bytes (only works for similar lengths)
// fill 512 bytes, so 512 / sizeof(T) values and at least 128 values.
// note: this only works for fixed-size types
auto batch_size = 128 * sizeof(T);
auto to_process = size;
auto total_size = sizeof(T) * size;
// and the cipher
uint32_t batch_size;
if (to_process > BATCH_SIZE) {
batch_size = BATCH_SIZE;
} else {
batch_size = to_process;
}
auto batch_size_in_bytes = BATCH_SIZE * sizeof(T);
D_ASSERT(batch_size_in_bytes = 512);

// todo: assign buffer_p with the right size
encryption_state->Process(reinterpret_cast<const unsigned char*>(input_vector), total_size, lstate.buffer_p, total_size);
encryption_state->Process(
reinterpret_cast<const unsigned char *>(input_vector), total_size,
lstate.buffer_p, total_size);

auto index = 0;
auto batch_nr = 0;
// get counter from local state
uint32_t counter = 0;
uint8_t cipher = 0;
const size_t step = sizeof(T) / 16;
uint32_t step = sizeof(T) / 16;
uint64_t buffer_offset;

// TODO: for strings this all works slighly different
for(int i = 0; i + 128; i < (DEFAULT_STANDARD_VECTOR_SIZE / 128)){
// TODO: for strings this all works slightly different because the size is variable
while (to_process) {
buffer_offset = batch_nr * batch_size_in_bytes;
blob_child_data[batch_nr] =
StringVector::EmptyString(blob_child, batch_size_in_bytes);
*(uint32_t *)blob_child_data[batch_nr].GetPrefixWriteable() =
*(uint32_t *)lstate.buffer_p + buffer_offset;
memcpy(blob_child_data[batch_nr].GetDataWriteable(), lstate.buffer_p,
batch_size);

buffer_offset = batch_nr * sizeof(T) * 128;
// Allocate space in the dictionary vector (i.e. blob_child)
string_t batch_data = StringVector::EmptyString(blob_child, batch_size); // value size
*(uint32_t*) batch_data.GetPrefixWriteable() = *(uint32_t *) lstate.buffer_p + buffer_offset;
memcpy(batch_data.GetDataWriteable(), lstate.buffer_p, batch_size);
blob_child_data[batch_nr].Finalize();

// set index in selection vector
for (int j = 0; j++; j < 128){
cipher = j % step;
counter += (cipher == 0 && index != 0) ? 1 : 0;
cipher_vec->SetValue(index, Value::TINYINT(cipher));
counter_vec->SetValue(index, Value::UINTEGER(counter));
blob_sel.set_index(index, batch_nr);
index++;
for (uint32_t j = 0; j < batch_size; j++) {

// do cipher + counter in 1. U32T
cipher = index % 128;
counter = index;
// cipher = j % step;
// counter += (cipher == 0 && index != 0) ? 1 : 0;
// todo: also fix this to blob_sel_data?
blob_sel.set_index(index, batch_nr);
cipher_vec_data[index] = batch_nr;
counter_vec_data[index] = counter;
index++;
}

batch_nr++;

// todo: optimize
if (to_process > BATCH_SIZE) {
to_process -= BATCH_SIZE;
} else {
// processing finalized
to_process = 0;
break;
}

if (to_process < BATCH_SIZE) {
batch_size = to_process;
batch_size_in_bytes = to_process * sizeof(T);
}
}
}


template <typename T>
void DecryptFromEtype(Vector &input_vector, uint64_t size,
ExpressionState &state, Vector &result) {
Expand Down Expand Up @@ -267,7 +314,7 @@ ScalarFunctionSet GetEncryptionVectorizedFunction() {
LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT},
{"nonce_lo", LogicalType::UBIGINT},
{"counter", LogicalType::UINTEGER},
{"cipher", LogicalType::UINTEGER},
{"cipher", LogicalType::TINYINT},
{"value", LogicalType::BLOB}}),
EncryptDataVectorized, EncryptFunctionData::EncryptBind, nullptr, nullptr, SimpleEncryptionFunctionLocalState::Init));
}
Expand Down

0 comments on commit c11defa

Please sign in to comment.