Skip to content

Commit

Permalink
added CTR and OCB algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
ccfelius committed Nov 5, 2024
1 parent 39ac907 commit 66385bd
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 31 deletions.
43 changes: 27 additions & 16 deletions src/core/crypto/crypto_primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ void hex256(hash_bytes &in, hash_str &out) {
}
}

// it's nowhere defined so this is fine
const EVP_CIPHER *GetCipher(const string &key, AESStateSSL::Algorithm algorithm) {

switch(algorithm) {
Expand Down Expand Up @@ -75,15 +74,15 @@ const EVP_CIPHER *GetCipher(const string &key, AESStateSSL::Algorithm algorithm)
}
}

Check warning on line 75 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64, x64-windows-static-md)

'duckdb::GetCipher': not all control paths return a value [D:\a\simple-encryption\simple-encryption\build\release\extension\simple_encryption\simple_encryption_extension.vcxproj]

Check warning on line 75 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64, x64-windows-static-md)

'duckdb::GetCipher': not all control paths return a value [D:\a\simple-encryption\simple-encryption\build\release\extension\simple_encryption\simple_encryption_loadable_extension.vcxproj]

Check warning on line 75 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64, x64-windows-static-md)

'duckdb::GetCipher': not all control paths return a value [D:\a\simple-encryption\simple-encryption\build\release\extension\simple_encryption\simple_encryption_extension.vcxproj]

Check warning on line 75 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64, x64-windows-static-md)

'duckdb::GetCipher': not all control paths return a value [D:\a\simple-encryption\simple-encryption\build\release\extension\simple_encryption\simple_encryption_loadable_extension.vcxproj]

Check warning on line 75 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64_mingw, x64-mingw-static)

'duckdb::GetCipher': not all control paths return a value [D:\a\simple-encryption\simple-encryption\build\release\extension\simple_encryption\simple_encryption_extension.vcxproj]

Check warning on line 75 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64_mingw, x64-mingw-static)

'duckdb::GetCipher': not all control paths return a value [D:\a\simple-encryption\simple-encryption\build\release\extension\simple_encryption\simple_encryption_loadable_extension.vcxproj]

Check warning on line 75 in src/core/crypto/crypto_primitives.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64_rtools, x64-mingw-static)

control reaches end of non-void function [-Wreturn-type]

AESStateSSL::AESStateSSL() : gcm_context(EVP_CIPHER_CTX_new()) {
if (!(gcm_context)) {
AESStateSSL::AESStateSSL() : context(EVP_CIPHER_CTX_new()) {
if (!(context)) {
throw InternalException("AES GCM failed with initializing context");
}
}

AESStateSSL::~AESStateSSL() {
// Clean up
EVP_CIPHER_CTX_free(gcm_context);
EVP_CIPHER_CTX_free(context);
}

bool AESStateSSL::IsOpenSSL() {
Expand Down Expand Up @@ -113,15 +112,15 @@ void AESStateSSL::InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, const

mode = ENCRYPT;

if (1 != EVP_EncryptInit_ex(gcm_context, GetCipher(*key, algorithm), NULL, const_data_ptr_cast(key->data()), iv)) {
if (1 != EVP_EncryptInit_ex(context, GetCipher(*key, algorithm), NULL, const_data_ptr_cast(key->data()), iv)) {
throw InternalException("EncryptInit failed");
}
}

void AESStateSSL::InitializeDecryption(const_data_ptr_t iv, idx_t iv_len, const string *key) {
mode = DECRYPT;

if (1 != EVP_DecryptInit_ex(gcm_context, GetCipher(*key, algorithm), NULL, const_data_ptr_cast(key->data()), iv)) {
if (1 != EVP_DecryptInit_ex(context, GetCipher(*key, algorithm), NULL, const_data_ptr_cast(key->data()), iv)) {
throw InternalException("DecryptInit failed");
}
}
Expand All @@ -130,14 +129,14 @@ size_t AESStateSSL::Process(const_data_ptr_t in, idx_t in_len, data_ptr_t out, i

switch (mode) {
case ENCRYPT:
if (1 != EVP_EncryptUpdate(gcm_context, data_ptr_cast(out), reinterpret_cast<int *>(&out_len),
if (1 != EVP_EncryptUpdate(context, data_ptr_cast(out), reinterpret_cast<int *>(&out_len),
const_data_ptr_cast(in), (int)in_len)) {
throw InternalException("Encryption failed at OpenSSL EVP_EncryptUpdate");
}
break;

case DECRYPT:
if (1 != EVP_DecryptUpdate(gcm_context, data_ptr_cast(out), reinterpret_cast<int *>(&out_len),
if (1 != EVP_DecryptUpdate(context, data_ptr_cast(out), reinterpret_cast<int *>(&out_len),
const_data_ptr_cast(in), (int)in_len)) {

throw InternalException("Decryption failed at OpenSSL EVP_DecryptUpdate");
Expand All @@ -156,23 +155,35 @@ size_t AESStateSSL::Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_
auto text_len = out_len;

switch (mode) {

case ENCRYPT:
if (1 != EVP_EncryptFinal_ex(gcm_context, data_ptr_cast(out) + out_len, reinterpret_cast<int *>(&out_len))) {
if (1 != EVP_EncryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast<int *>(&out_len))) {
throw InternalException("EncryptFinal failed");
}
text_len += out_len;
// The computed tag is written at the end of a chunk
if (1 != EVP_CIPHER_CTX_ctrl(gcm_context, EVP_CTRL_GCM_GET_TAG, tag_len, tag)) {

if (algorithm == CTR) {
return text_len;
}

// The computed tag is written at the end of a chunk for OCB and GCM
if (1 != EVP_CIPHER_CTX_ctrl(context, EVP_CTRL_GCM_GET_TAG, tag_len,
tag)) {
throw InternalException("Calculating the tag failed");
}
return text_len;

case DECRYPT:
// Set expected tag value
if (!EVP_CIPHER_CTX_ctrl(gcm_context, EVP_CTRL_GCM_SET_TAG, tag_len, tag)) {
throw InternalException("Finalizing tag failed");

if (algorithm != CTR){
// Set expected tag value
if (!EVP_CIPHER_CTX_ctrl(context, EVP_CTRL_GCM_SET_TAG, tag_len,
tag)) {
throw InternalException("Finalizing tag failed");
}
}

// EVP_DecryptFinal() will return an error code if final block is not correctly formatted.
int ret = EVP_DecryptFinal_ex(gcm_context, data_ptr_cast(out) + out_len, reinterpret_cast<int *>(&out_len));
int ret = EVP_DecryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast<int *>(&out_len));
text_len += out_len;

if (ret > 0) {
Expand Down
34 changes: 21 additions & 13 deletions src/core/functions/scalar/encrypt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#define TEST_KEY "0123456789112345"
#define MAX_BUFFER_SIZE 1024
#define MAX_BUFFER_SIZE_2 8096

#include "duckdb.hpp"
#include "duckdb/common/exception.hpp"
Expand Down Expand Up @@ -48,8 +49,8 @@ shared_ptr<EncryptionState> InitializeEncryption(ExpressionState &state) {
// For now, hardcode everything
// for some reason, this is 12
const string key = TEST_KEY;
unsigned char iv[12];
memcpy((void *)iv, "12345678901", 12);
unsigned char iv[16];
// memcpy((void *)iv, "12345678901", 16);
//
// // TODO; construct nonce based on immutable ROW_ID + hash(col_name)
// iv[12] = 0x00;
Expand All @@ -58,8 +59,8 @@ shared_ptr<EncryptionState> InitializeEncryption(ExpressionState &state) {
// iv[15] = 0x00;

auto encryption_state = InitializeCryptoState(state);
// encryption_state->GenerateRandomData(iv, 12);
encryption_state->InitializeEncryption(iv, 12, &key);
// encryption_state->GenerateRandomData(iv, 16);
encryption_state->InitializeEncryption(iv, 16, &key);

return encryption_state;
}
Expand All @@ -68,8 +69,8 @@ shared_ptr<EncryptionState> InitializeDecryption(ExpressionState &state) {

// For now, hardcode everything
const string key = TEST_KEY;
unsigned char iv[12];
memcpy((void *)iv, "12345678901", 12);
unsigned char iv[16];
memcpy((void *)iv, "12345678901", 16);

Check warning on line 73 in src/core/functions/scalar/encrypt.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / Windows (windows_amd64_rtools, x64-mingw-static)

'void* memcpy(void*, const void*, size_t)' reading 16 bytes from a region of size 12 [-Wstringop-overflow=]
//
// // TODO; construct nonce based on immutable ROW_ID + hash(col_name)
// iv[12] = 0x00;
Expand Down Expand Up @@ -122,31 +123,41 @@ static void EncryptData(DataChunk &args, ExpressionState &state,
Vector &result) {

auto &name_vector = args.data[0];
// auto encryption_state = InitializeEncryption(state);
auto encryption_state = InitializeEncryption(state);

// we do this here, we actually need to keep track of a pointer all the time
uint8_t encryption_buffer[MAX_BUFFER_SIZE];
uint8_t *buffer = encryption_buffer;
// uint8_t encryption_buffer[MAX_BUFFER_SIZE];
// uint8_t *buffer = encryption_buffer;

// TODO; handle all different input types
UnaryExecutor::Execute<string_t, string_t>(
name_vector, result, args.size(), [&](string_t name) {

// renew for each value
uint8_t encryption_buffer[MAX_BUFFER_SIZE];
uint8_t *buffer = encryption_buffer;
// For now; new encryption state for every new value
// does this has to do with multithreading or something?
auto encryption_state = InitializeEncryption(state);
auto size = name.GetSize();
//std::fill(encryption_buffer, encryption_buffer + size, 0);
// round the size to multiple of 16 for encryption efficiency
// size = (size + 15) & ~15;

unsigned char iv[16];
const string key = TEST_KEY;

encryption_state->GenerateRandomData(iv, 16);
encryption_state->InitializeEncryption(iv, 16, &key);

encryption_state->Process(
reinterpret_cast<const_data_ptr_t>(name.GetData()), size, buffer, size);

D_ASSERT(MAX_BUFFER_SIZE ==
sizeof(encryption_buffer) / sizeof(encryption_buffer[0]));

string_t encrypted_data(reinterpret_cast<const char *>(buffer), size);


auto printable_encrypted_data = Blob::ToString(encrypted_data);

D_ASSERT(CheckEncryption(printable_encrypted_data, buffer, size, reinterpret_cast<const_data_ptr_t>(name.GetData()), state) == 1);
Expand All @@ -155,9 +166,6 @@ static void EncryptData(DataChunk &args, ExpressionState &state,
// unsigned char tag[16];
// encryption_state->Finalize(buffer, 0, tag, 16);

// buffer pointer stays at the start haha
// buffer -= size;

return printable_encrypted_data;
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ class DUCKDB_EXTENSION_API AESStateSSL : public duckdb::EncryptionState {

private:
bool ssl = true;
EVP_CIPHER_CTX *gcm_context;
EVP_CIPHER_CTX *context;
Mode mode;

// default value is GCM
Algorithm algorithm = GCM;
Algorithm algorithm = CTR;
};

} // namespace duckdb
Expand Down

0 comments on commit 66385bd

Please sign in to comment.