diff --git a/src/core/crypto/crypto_primitives.cpp b/src/core/crypto/crypto_primitives.cpp index 1c40209..534f576 100644 --- a/src/core/crypto/crypto_primitives.cpp +++ b/src/core/crypto/crypto_primitives.cpp @@ -33,17 +33,45 @@ void hex256(hash_bytes &in, hash_str &out) { } } -const EVP_CIPHER *GetCipher(const string &key) { - // For now, we only support GCM ciphers - switch (key.size()) { - case 16: - return EVP_aes_128_gcm(); - case 24: - return EVP_aes_192_gcm(); - case 32: - return EVP_aes_256_gcm(); - default: - throw InternalException("Invalid AES key length"); +// it's nowhere defined so this is fine +const EVP_CIPHER *GetCipher(const string &key, AESStateSSL::Algorithm algorithm) { + + switch(algorithm) { + case AESStateSSL::GCM: + switch (key.size()) { + case 16: + return EVP_aes_128_gcm(); + case 24: + return EVP_aes_192_gcm(); + case 32: + return EVP_aes_256_gcm(); + default: + throw InternalException("Invalid AES key length"); + } + + case AESStateSSL::CTR: + switch (key.size()) { + case 16: + return EVP_aes_128_ctr(); + case 24: + return EVP_aes_192_ctr(); + case 32: + return EVP_aes_256_ctr(); + default: + throw InternalException("Invalid AES key length"); + } + case AESStateSSL::OCB: + // For now, we only support GCM ciphers + switch (key.size()) { + case 16: + return EVP_aes_128_ocb(); + case 24: + return EVP_aes_192_ocb(); + case 32: + return EVP_aes_256_ocb(); + default: + throw InternalException("Invalid AES key length"); + } } } @@ -62,15 +90,30 @@ bool AESStateSSL::IsOpenSSL() { return ssl; } +void AESStateSSL::SetEncryptionAlgorithm(string_t s_algorithm) { + + if (s_algorithm == "GCM") { + algorithm = GCM; + } else if (s_algorithm == "CTR") { + algorithm = CTR; + } else if (s_algorithm == "OCB") { + algorithm = OCB; + } else { + throw InvalidInputException("Invalid encryption algorithm"); + } +} + void AESStateSSL::GenerateRandomData(data_ptr_t data, idx_t len) { // generate random bytes for nonce RAND_bytes(data, len); } void AESStateSSL::InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, const string *key) { + // somewhere here or earlier we should set the encryption algorithm (maybe manually) + mode = ENCRYPT; - if (1 != EVP_EncryptInit_ex(gcm_context, GetCipher(*key), NULL, const_data_ptr_cast(key->data()), iv)) { + if (1 != EVP_EncryptInit_ex(gcm_context, GetCipher(*key, algorithm), NULL, const_data_ptr_cast(key->data()), iv)) { throw InternalException("EncryptInit failed"); } } @@ -78,7 +121,7 @@ void AESStateSSL::InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, const void AESStateSSL::InitializeDecryption(const_data_ptr_t iv, idx_t iv_len, const string *key) { mode = DECRYPT; - if (1 != EVP_DecryptInit_ex(gcm_context, GetCipher(*key), NULL, const_data_ptr_cast(key->data()), iv)) { + if (1 != EVP_DecryptInit_ex(gcm_context, GetCipher(*key, algorithm), NULL, const_data_ptr_cast(key->data()), iv)) { throw InternalException("DecryptInit failed"); } } diff --git a/src/include/simple_encryption/core/crypto/crypto_primitives.hpp b/src/include/simple_encryption/core/crypto/crypto_primitives.hpp index 324d562..c2de1bf 100644 --- a/src/include/simple_encryption/core/crypto/crypto_primitives.hpp +++ b/src/include/simple_encryption/core/crypto/crypto_primitives.hpp @@ -28,6 +28,9 @@ class DUCKDB_EXTENSION_API AESStateSSL : public duckdb::EncryptionState { explicit AESStateSSL(); ~AESStateSSL() override; + // We can use GCM, CTR or OCB + enum Algorithm { GCM, CTR, OCB }; + public: bool IsOpenSSL() override; void InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, const std::string *key) override; @@ -36,10 +39,16 @@ class DUCKDB_EXTENSION_API AESStateSSL : public duckdb::EncryptionState { size_t Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_t tag_len) override; void GenerateRandomData(data_ptr_t data, idx_t len) override; + // crypto-specific functions + void SetEncryptionAlgorithm(string_t s_algorithm); + private: bool ssl = true; EVP_CIPHER_CTX *gcm_context; Mode mode; + + // default value is GCM + Algorithm algorithm = GCM; }; } // namespace duckdb