diff --git a/.gitignore b/.gitignore index a5b17a57..05413b30 100644 --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,6 @@ /cipherstash-proxy.local.toml mise.local.toml tests/pg/data** -tests/sql/cipherstash-encrypt.sql -tests/sql/cipherstash-encrypt-uninstall.sql .vscode rust-toolchain.toml @@ -13,8 +11,9 @@ rust-toolchain.toml # release artifacts /cipherstash-proxy -/cipherstash-eql.sql /packages/cipherstash-proxy/eql-version-at-build-time.txt +/cipherstash-encrypt.sql +/cipherstash-encrypt-uninstall.sql # credentials for local dev .env.proxy.docker diff --git a/Cargo.lock b/Cargo.lock index b2e0ee88..3c322604 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -181,6 +181,7 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" dependencies = [ + "serde", "zeroize", ] @@ -287,6 +288,61 @@ dependencies = [ "fs_extra", ] +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = 
"backtrace" version = "0.3.74" @@ -317,6 +373,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" +[[package]] +name = "base32" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "022dfe9eb35f19ebbcb51e0b40a5ab759f46ad60cadf7297e0bd085afb50e076" + [[package]] name = "base64" version = "0.22.1" @@ -588,9 +650,9 @@ dependencies = [ [[package]] name = "cipherstash-client" -version = "0.18.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f099b1db6cf37b0ca36e9c8e0c2dade20f2035804e225f52475d44e750dd5dd5" +checksum = "8fe21509165da6daf50b84d4dc9bc46b558e5afb34db75dbd2371b963faabe4d" dependencies = [ "aes-gcm-siv", "anyhow", @@ -605,6 +667,7 @@ dependencies = [ "cipherstash-config", "cipherstash-core", "cllw-ore", + "cts-common", "derive_more", "dirs", "futures", @@ -724,9 +787,12 @@ name = "cipherstash-proxy-integration" version = "0.1.0" dependencies = [ "chrono", + "cipherstash-client", + "cipherstash-config", "cipherstash-proxy", "clap", "fake 4.2.0", + "hex", "rand 0.9.0", "recipher", "rustls", @@ -739,6 +805,7 @@ dependencies = [ "tokio-rustls", "tracing", "tracing-subscriber", + "uuid", "webpki-roots", ] @@ -953,6 +1020,27 @@ dependencies = [ "cipher 0.4.4", ] +[[package]] +name = "cts-common" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "938da7d14d05c2769bf7ae33c5a395eb6a34ffdd25ec286e97702ae563314f9b" +dependencies = [ + "arrayvec", + "axum", + "base32", + "diesel", + "fake 3.1.0", + "http", + "miette", + "rand 0.8.5", + "regex", + "serde", + "thiserror 1.0.69", + "url", + "vitaminc", +] + [[package]] name = "darling" version = "0.20.10" @@ -973,6 +1061,7 @@ dependencies = [ "ident_case", "proc-macro2", "quote", + "strsim", "syn 2.0.100", ] @@ -1075,6 +1164,39 @@ version = "1.6.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "dc55fe0d1f6c107595572ec8b107c0999bb1a2e0b75e37429a4fb0d6474a0e7d" +[[package]] +name = "diesel" +version = "2.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff3e1edb1f37b4953dd5176916347289ed43d7119cc2e6c7c3f7849ff44ea506" +dependencies = [ + "chrono", + "diesel_derives", + "uuid", +] + +[[package]] +name = "diesel_derives" +version = "2.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68d4216021b3ea446fd2047f5c8f8fe6e98af34508a254a01e4d6bc1e844f84d" +dependencies = [ + "diesel_table_macro_syntax", + "dsl_auto_type", + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "diesel_table_macro_syntax" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "209c735641a413bc68c4923a9d6ad4bcb3ca306b794edaa7eb0b3228a99ffb25" +dependencies = [ + "syn 2.0.100", +] + [[package]] name = "diff" version = "0.1.13" @@ -1123,6 +1245,20 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "dsl_auto_type" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139ae9aca7527f85f26dd76483eb38533fd84bd571065da1739656ef71c5ff5b" +dependencies = [ + "darling", + "either", + "heck", + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "dummy" version = "0.8.0" @@ -1135,6 +1271,18 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "dummy" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abcba80bdf851db5616e27ff869399468e2d339d7c6480f5887681e6bdfc2186" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "dummy" version = "0.11.0" @@ -1172,6 +1320,7 @@ dependencies = [ "thiserror 2.0.12", "tracing", "tracing-subscriber", + "vec1", ] [[package]] @@ -1223,12 +1372,24 @@ dependencies = [ "uuid", ] +[[package]] +name = "fake" +version 
= "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aef603df4ba9adbca6a332db7da6f614f21eafefbaf8e087844e452fdec152d0" +dependencies = [ + "deunicode", + "dummy 0.9.2", + "rand 0.8.5", +] + [[package]] name = "fake" version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b591050272097cc85b2f3c1cc4817ba4560057d10fcae6f7339f1cf622da0a0f" dependencies = [ + "chrono", "deunicode", "dummy 0.11.0", "rand 0.9.0", @@ -2034,6 +2195,12 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "md-5" version = "0.10.6" @@ -3328,6 +3495,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.8" @@ -3905,6 +4082,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -4134,6 +4312,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "vec1" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eab68b56840f69efb0fefbe3ab6661499217ffdc58e2eef7c3f6f69835386322" + [[package]] name = "version_check" version = "0.9.5" diff --git a/docs/errors.md b/docs/errors.md index 7e5faceb..1105d064 100644 --- a/docs/errors.md +++ b/docs/errors.md @@ -314,7 +314,7 @@ For example: ## Unknown Column -The column has an encrypted type (PostgreSQL `cs_encrypted_v1` type ) with no encryption configuration. 
+The column has an encrypted type (PostgreSQL `eql_v1_encrypted` type) with no encryption configuration. Without the configuration, Cipherstash Proxy does not know how to encrypt the column. Any data is unprotected and unencrypted. @@ -341,7 +341,7 @@ Column 'column_name' in table 'table_name' has no Encrypt configuration ## Unknown Table -The table has one or more encrypted columns (PostgreSQL `cs_encrypted_v1` type ) with no encryption configuration. +The table has one or more encrypted columns (PostgreSQL `eql_v1_encrypted` type) with no encryption configuration. Without the configuration, Cipherstash Proxy does not know how to encrypt the column. Any data is unprotected and unencrypted. diff --git a/docs/getting-started/schema-example.sql b/docs/getting-started/schema-example.sql index 29e3e743..0120cde4 100644 --- a/docs/getting-started/schema-example.sql +++ b/docs/getting-started/schema-example.sql @@ -1,12 +1,12 @@ -TRUNCATE TABLE cs_configuration_v1; +TRUNCATE TABLE public.eql_v1_configuration; -- Exciting cipherstash table DROP TABLE IF EXISTS users; CREATE TABLE users ( id SERIAL PRIMARY KEY, - encrypted_email cs_encrypted_v1, - encrypted_dob cs_encrypted_v1, - encrypted_salary cs_encrypted_v1 + encrypted_email eql_v1_encrypted, + encrypted_dob eql_v1_encrypted, + encrypted_salary eql_v1_encrypted ); SELECT cs_add_index_v1( diff --git a/docs/how-to.md b/docs/how-to.md index a38f80cb..5f906a84 100644 --- a/docs/how-to.md +++ b/docs/how-to.md @@ -153,7 +153,7 @@ You can also install EQL by running [the installation script](https://github.com Once you have installed EQL, you can see what version is installed by querying the database: ```sql -SELECT cs_eql_version(); +SELECT eql_v1.version(); ``` This will output the version of EQL installed. @@ -162,22 +162,22 @@ This will output the version of EQL installed. In your existing PostgreSQL database, you store your data in tables and columns. 
Those columns have types like `integer`, `text`, `timestamp`, and `boolean`. -When storing encrypted data in PostgreSQL with Proxy, you use a special column type called `cs_encrypted_v1`, which is [provided by EQL](#setting-up-the-database-schema). -`cs_encrypted_v1` is a container column type that can be used for any type of encrypted data you want to store or search, whether they are numbers (`int`, `small_int`, `big_int`), text (`text`), dates and times (`date`), or booleans (`boolean`). +When storing encrypted data in PostgreSQL with Proxy, you use a special column type called `eql_v1_encrypted`, which is [provided by EQL](#setting-up-the-database-schema). +`eql_v1_encrypted` is a container column type that can be used for any type of encrypted data you want to store or search, whether they are numbers (`int`, `small_int`, `big_int`), text (`text`), dates and times (`date`), or booleans (`boolean`). Create a table with an encrypted column for `email`: ```sql CREATE TABLE users ( id SERIAL PRIMARY KEY, - email cs_encrypted_v1 + email eql_v1_encrypted ) ``` This creates a `users` table with two columns: - `id`, an autoincrementing integer column that is the primary key for the record - - `email`, a `cs_encrypted_v1` column + - `email`, an `eql_v1_encrypted` column There are important differences between the plaintext columns you've traditionally used in PostgreSQL and encrypted columns with CipherStash Proxy: diff --git a/mise.toml b/mise.toml index 9cf04e2a..64e563df 100644 --- a/mise.toml +++ b/mise.toml @@ -409,27 +409,28 @@ fi """ [tasks."postgres:setup"] +depends = ["postgres:eql:teardown"] alias = 's' description = "Installs EQL and applies schema to database" run = """ #!/bin/bash cd tests mise run postgres:fail_if_not_running -mise run postgres:eql:download -cat sql/cipherstash-encrypt.sql | docker exec -i postgres${CONTAINER_SUFFIX} psql 
postgresql://${CS_DATABASE__USERNAME}:${CS_DATABASE__PASSWORD}@${CS_DATABASE__HOST}:${CS_DATABASE__PORT}/${CS_DATABASE__NAME} -f- +cat sql/schema-uninstall.sql | docker exec -i postgres${CONTAINER_SUFFIX} psql postgresql://${CS_DATABASE__USERNAME}:${CS_DATABASE__PASSWORD}@${CS_DATABASE__HOST}:${CS_DATABASE__PORT}/${CS_DATABASE__NAME} -f- +cat ../cipherstash-encrypt-uninstall.sql | docker exec -i postgres${CONTAINER_SUFFIX} psql postgresql://${CS_DATABASE__USERNAME}:${CS_DATABASE__PASSWORD}@${CS_DATABASE__HOST}:${CS_DATABASE__PORT}/${CS_DATABASE__NAME} -f- +cat ../cipherstash-encrypt.sql | docker exec -i postgres${CONTAINER_SUFFIX} psql postgresql://${CS_DATABASE__USERNAME}:${CS_DATABASE__PASSWORD}@${CS_DATABASE__HOST}:${CS_DATABASE__PORT}/${CS_DATABASE__NAME} -f- cat sql/schema.sql | docker exec -i postgres${CONTAINER_SUFFIX} psql postgresql://${CS_DATABASE__USERNAME}:${CS_DATABASE__PASSWORD}@${CS_DATABASE__HOST}:${CS_DATABASE__PORT}/${CS_DATABASE__NAME} -f- """ [tasks."postgres:eql:teardown"] -alias = 's' +depends = ["eql:download"] description = "Uninstalls EQL and removes schema from database" run = """ #!/bin/bash cd tests mise run postgres:fail_if_not_running -mise run postgres:eql:download cat sql/schema-uninstall.sql | docker exec -i postgres${CONTAINER_SUFFIX} psql postgresql://${CS_DATABASE__USERNAME}:${CS_DATABASE__PASSWORD}@${CS_DATABASE__HOST}:${CS_DATABASE__PORT}/${CS_DATABASE__NAME} -f- -cat sql/cipherstash-encrypt-uninstall.sql | docker exec -i postgres${CONTAINER_SUFFIX} psql postgresql://${CS_DATABASE__USERNAME}:${CS_DATABASE__PASSWORD}@${CS_DATABASE__HOST}:${CS_DATABASE__PORT}/${CS_DATABASE__NAME} -f- +cat ../cipherstash-encrypt-uninstall.sql | docker exec -i postgres${CONTAINER_SUFFIX} psql postgresql://${CS_DATABASE__USERNAME}:${CS_DATABASE__PASSWORD}@${CS_DATABASE__HOST}:${CS_DATABASE__PORT}/${CS_DATABASE__NAME} -f- """ [tasks."postgres:up"] @@ -490,34 +491,32 @@ for d in tests/pg/data-*; do done """ - -[tasks."postgres:eql:download"] 
+[tasks."eql:download"] alias = 'e' -description = "Download latest EQL release" +description = "Download latest EQL release or use local copy" dir = "{{config_root}}/tests" outputs = [ - "{{config_root}}/tests/sql/cipherstash-encrypt.sql", - "{{config_root}}/tests/sql/cipherstash-encrypt-uninstall.sql", + "{{config_root}}/cipherstash-encrypt.sql", + "{{config_root}}/cipherstash-encrypt-uninstall.sql", ] run = """ # install script if [ -z "$CS_EQL_PATH" ]; then - curl -sLo sql/cipherstash-encrypt.sql https://github.com/cipherstash/encrypt-query-language/releases/download/${CS_EQL_VERSION}/cipherstash-encrypt.sql + curl -sLo "{{config_root}}/cipherstash-encrypt.sql" https://github.com/cipherstash/encrypt-query-language/releases/download/${CS_EQL_VERSION}/cipherstash-encrypt.sql else - echo "Using EQL: ${CS_EQL_PATH}" - cp "$CS_EQL_PATH" sql/cipherstash-encrypt.sql + echo "Using EQL: ${CS_EQL_PATH}/cipherstash-encrypt.sql" + cp "$CS_EQL_PATH/cipherstash-encrypt.sql" "{{config_root}}/cipherstash-encrypt.sql" fi # uninstall script -if [ -z "$CS_EQL_UNINSTALL_PATH" ]; then - curl -sLo sql/cipherstash-encrypt-uninstall.sql https://github.com/cipherstash/encrypt-query-language/releases/download/${CS_EQL_VERSION}/cipherstash-encrypt-uninstall.sql +if [ -z "$CS_EQL_PATH" ]; then + curl -sLo "{{config_root}}/cipherstash-encrypt-uninstall.sql" https://github.com/cipherstash/encrypt-query-language/releases/download/${CS_EQL_VERSION}/cipherstash-encrypt-uninstall.sql else - echo "Using EQL: ${CS_EQL_PATH}" - cp "$CS_EQL_UNINSTALL_PATH" sql/cipherstash-encrypt-uninstall.sql + echo "Using EQL: ${CS_EQL_PATH}/cipherstash-encrypt-uninstall.sql" + cp "$CS_EQL_PATH/cipherstash-encrypt-uninstall.sql" "{{config_root}}/cipherstash-encrypt-uninstall.sql" fi """ - [tasks."python:test"] dir = "{{config_root}}/tests" description = "Runs python tests" @@ -567,7 +566,7 @@ cp -v {{config_root}}/target/{{ target }}/release/cipherstash-proxy {{config_roo """ [tasks."build:docker"] -depends = 
["build:docker:fetch_eql"] +depends = ["eql:download"] description = "Build a Docker image for cipherstash-proxy" run = """ {% set default_platform = "linux/" ~ arch() | replace(from="x86_64", to="amd64") %} diff --git a/packages/cipherstash-proxy-integration/Cargo.toml b/packages/cipherstash-proxy-integration/Cargo.toml index 546e1ff8..74ea4334 100644 --- a/packages/cipherstash-proxy-integration/Cargo.toml +++ b/packages/cipherstash-proxy-integration/Cargo.toml @@ -24,5 +24,9 @@ tracing-subscriber = { workspace = true } webpki-roots = "0.26.7" [dev-dependencies] +cipherstash-client = { version = "0.20.0", features = ["tokio"] } +cipherstash-config = "0.2.3" clap = "4.5.32" -fake = { version = "4", features = ["derive"] } +fake = { version = "4", features = ["chrono", "derive"] } +hex = "0.4.3" +uuid = { version = "1.11.0", features = ["serde", "v4"] } diff --git a/packages/cipherstash-proxy-integration/src/extended_protocol_error_messages.rs b/packages/cipherstash-proxy-integration/src/extended_protocol_error_messages.rs index c45ec4af..37ac99ea 100644 --- a/packages/cipherstash-proxy-integration/src/extended_protocol_error_messages.rs +++ b/packages/cipherstash-proxy-integration/src/extended_protocol_error_messages.rs @@ -67,10 +67,10 @@ mod tests { let msg = err.to_string(); // This is similar to below. The error message comes from tokio-postgres when Proxy - // returns cs_encrypted_v1 and the client cannot convert to a string. + // returns eql_v1_encrypted and the client cannot convert to a string. 
// If mapping errors are enabled (enable_mapping_errors or CS_DEVELOPMENT__ENABLE_MAPPING_ERRORS), // then Proxy will return an error that says "Column X in table Y has no Encrypt configuration" - assert_eq!(msg, "error serializing parameter 1: cannot convert between the Rust type `&str` and the Postgres type `cs_encrypted_v1`"); + assert_eq!(msg, "error serializing parameter 1: cannot convert between the Rust type `&str` and the Postgres type `eql_v1_encrypted`"); } else { unreachable!(); } diff --git a/packages/cipherstash-proxy-integration/src/generate.rs b/packages/cipherstash-proxy-integration/src/generate.rs new file mode 100644 index 00000000..13d11820 --- /dev/null +++ b/packages/cipherstash-proxy-integration/src/generate.rs @@ -0,0 +1,259 @@ +#[cfg(test)] +mod tests { + use crate::common::trace; + use cipherstash_client::config::EnvSource; + use cipherstash_client::credentials::auto_refresh::AutoRefresh; + use cipherstash_client::encryption::{ + Encrypted, EncryptedSteVecTerm, JsonIndexer, JsonIndexerOptions, OreTerm, Plaintext, + PlaintextTarget, ReferencedPendingPipeline, + }; + use cipherstash_client::{encryption::ScopedCipher, zerokms::EncryptedRecord}; + use cipherstash_client::{ConsoleConfig, CtsConfig, ZeroKMSConfig}; + use cipherstash_config::column::{Index, IndexType}; + use cipherstash_config::{ColumnConfig, ColumnType}; + use cipherstash_proxy::Identifier; + use serde::{Deserialize, Serialize}; + use std::sync::Arc; + use tracing::info; + use uuid::Uuid; + + pub mod option_mp_base85 { + use cipherstash_client::zerokms::encrypted_record::formats::mp_base85; + use cipherstash_client::zerokms::EncryptedRecord; + use serde::{Deserialize, Deserializer, Serializer}; + + pub fn serialize( + value: &Option, + serializer: S, + ) -> Result + where + S: Serializer, + { + match value { + Some(record) => mp_base85::serialize(record, serializer), + None => serializer.serialize_none(), + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result, 
D::Error> + where + D: Deserializer<'de>, + { + let result = Option::::deserialize(deserializer)?; + Ok(result) + } + } + + #[derive(Debug, Deserialize, Serialize)] + pub struct EqlEncrypted { + #[serde(rename = "c", with = "option_mp_base85")] + ciphertext: Option, + #[serde(rename = "i")] + identifier: Identifier, + #[serde(rename = "v")] + version: u16, + + #[serde(rename = "o")] + ore_index: Option>, + #[serde(rename = "m")] + match_index: Option>, + #[serde(rename = "u")] + unique_index: Option, + + #[serde(rename = "s")] + selector: Option, + + #[serde(rename = "b")] + blake3_index: Option, + + #[serde(rename = "ocf")] + ore_cclw_fixed_index: Option, + #[serde(rename = "ocv")] + ore_cclw_var_index: Option, + + #[serde(rename = "sv")] + ste_vec_index: Option>, + } + + #[derive(Debug, Deserialize, Serialize)] + pub struct EqlSteVecEncrypted { + #[serde(rename = "c", with = "option_mp_base85")] + ciphertext: Option, + + #[serde(rename = "s")] + selector: Option, + #[serde(rename = "b")] + blake3_index: Option, + #[serde(rename = "ocf")] + ore_cclw_fixed_index: Option, + #[serde(rename = "ocv")] + ore_cclw_var_index: Option, + } + + impl EqlEncrypted { + pub fn ste_vec(ste_vec_index: Vec) -> Self { + Self { + ste_vec_index: Some(ste_vec_index), + ciphertext: None, + identifier: Identifier { + table: "blah".to_string(), + column: "vtha".to_string(), + }, + version: 1, + ore_index: None, + match_index: None, + unique_index: None, + selector: None, + ore_cclw_fixed_index: None, + ore_cclw_var_index: None, + blake3_index: None, + } + } + } + impl EqlSteVecEncrypted { + pub fn ste_vec_element(selector: String, record: EncryptedRecord) -> Self { + Self { + ciphertext: Some(record), + selector: Some(selector), + ore_cclw_fixed_index: None, + ore_cclw_var_index: None, + blake3_index: None, + } + } + } + + #[tokio::test] + async fn generate_ste_vec() { + trace(); + + // clear().await; + // let client = connect_with_tls(PROXY).await; + + let console_config = 
ConsoleConfig::builder().with_env().build().unwrap(); + let cts_config = CtsConfig::builder().with_env().build().unwrap(); + let zerokms_config = ZeroKMSConfig::builder() + .add_source(EnvSource::default()) + .console_config(&console_config) + .cts_config(&cts_config) + .build_with_client_key() + .unwrap(); + let zerokms_client = zerokms_config + .create_client_with_credentials(AutoRefresh::new(zerokms_config.credentials())); + + let dataset_id = Uuid::parse_str("295504329cb045c398dc464c52a287a1").unwrap(); + + let cipher = Arc::new( + ScopedCipher::init(Arc::new(zerokms_client), Some(dataset_id)) + .await + .unwrap(), + ); + + let prefix = "prefix".to_string(); + + let column_config = ColumnConfig::build("column_name".to_string()) + .casts_as(ColumnType::JsonB) + .add_index(Index::new(IndexType::SteVec { + prefix: prefix.to_owned(), + })); + + // let mut value = + // serde_json::from_str::("{\"hello\": \"one\", \"n\": 10}").unwrap(); + + // let mut value = + // serde_json::from_str::("{\"hello\": \"two\", \"n\": 20}").unwrap(); + + let value = + serde_json::from_str::("{\"hello\": \"two\", \"n\": 30}").unwrap(); + + // let mut value = + // serde_json::from_str::("{\"hello\": \"world\", \"n\": 42}").unwrap(); + + // let mut value = + // serde_json::from_str::("{\"hello\": \"world\", \"n\": 42}").unwrap(); + + // let mut value = + // serde_json::from_str::("{\"blah\": { \"vtha\": 42 }}").unwrap(); + + let plaintext = Plaintext::JsonB(Some(value)); + + let idx = 0; + + let mut pipeline = ReferencedPendingPipeline::new(cipher.clone()); + let encryptable = PlaintextTarget::new(plaintext, column_config); + pipeline + .add_with_ref::(encryptable, idx) + .unwrap(); + + let mut encrypteds = vec![]; + + let mut result = pipeline.encrypt(None).await.unwrap(); + if let Some(Encrypted::SteVec(ste_vec)) = result.remove(idx) { + for entry in ste_vec { + let selector = hex::encode(entry.0.as_bytes()); + let term = entry.1; + let record = entry.2; + + let mut e = 
EqlSteVecEncrypted::ste_vec_element(selector, record); + + match term { + EncryptedSteVecTerm::Mac(items) => { + e.blake3_index = Some(hex::encode(&items)); + } + EncryptedSteVecTerm::OreFixed(o) => { + e.ore_cclw_fixed_index = Some(hex::encode(&o)); + } + EncryptedSteVecTerm::OreVariable(o) => { + e.ore_cclw_var_index = Some(hex::encode(&o)); + } + } + + encrypteds.push(e); + } + // info!("{:?}" = encrypteds); + } + + info!("---------------------------------------------"); + + let e = EqlEncrypted::ste_vec(encrypteds); + info!("{:?}" = ?e); + + let json = serde_json::to_value(e).unwrap(); + info!("{}", json); + + let indexer = JsonIndexer::new(JsonIndexerOptions { prefix }); + + info!("---------------------------------------------"); + + // Path + // let path: String = "$.blah.vtha".to_string(); + // let selector = Selector::parse(&path).unwrap(); + // let selector = indexer.generate_selector(selector, cipher.index_key()); + // let selector = hex::encode(selector.0); + // info!("{}", selector); + + // Comparison + let n = 30; + let term = OreTerm::Number(n); + + let term = indexer.generate_term(term, cipher.index_key()).unwrap(); + + match term { + EncryptedSteVecTerm::Mac(_) => todo!(), + EncryptedSteVecTerm::OreFixed(ore_cllw8_v1) => { + let term = hex::encode(ore_cllw8_v1.bytes); + info!("{n}: {term}"); + } + EncryptedSteVecTerm::OreVariable(_) => todo!(), + } + + // if let Some(ste_vec_index) = e.ste_vec_index { + // for e in ste_vec_index { + // info!("{}", e); + // if let Some(ct) = e.ciphertext { + // let decrypted = cipher.decrypt(encrypted).await?; + // info!("{}", decrypted); + // } + // } + // } + } +} diff --git a/packages/cipherstash-proxy/Cargo.toml b/packages/cipherstash-proxy/Cargo.toml index d63a2513..7ee18238 100644 --- a/packages/cipherstash-proxy/Cargo.toml +++ b/packages/cipherstash-proxy/Cargo.toml @@ -8,7 +8,7 @@ bigdecimal = { version = "0.4.6", features = ["serde-json"] } arc-swap = "1.7.1" bytes = { version = "1.9", default-features = 
false } chrono = { version = "0.4.39", features = ["clock"] } -cipherstash-client = { version = "0.18.0-pre.1", features = ["tokio"] } +cipherstash-client = { version = "0.20.0", features = ["tokio"] } cipherstash-config = "0.2.3" clap = { version = "4.5.31", features = ["derive", "env"] } config = { version = "0.15", features = [ diff --git a/packages/cipherstash-proxy/src/encrypt/config/manager.rs b/packages/cipherstash-proxy/src/encrypt/config/manager.rs index 31312533..df23d8ab 100644 --- a/packages/cipherstash-proxy/src/encrypt/config/manager.rs +++ b/packages/cipherstash-proxy/src/encrypt/config/manager.rs @@ -195,8 +195,7 @@ pub async fn load_encrypt_config(config: &DatabaseConfig) -> Result bool { let msg = e.to_string(); - msg.contains("cs_configuration_v1") && msg.contains("does not exist") + msg.contains("eql_v1_configuration") && msg.contains("does not exist") } diff --git a/packages/cipherstash-proxy/src/encrypt/mod.rs b/packages/cipherstash-proxy/src/encrypt/mod.rs index faac044e..fca8efd1 100644 --- a/packages/cipherstash-proxy/src/encrypt/mod.rs +++ b/packages/cipherstash-proxy/src/encrypt/mod.rs @@ -3,7 +3,8 @@ mod schema; use crate::{ config::TandemConfig, - connect, eql, + connect, + eql::{self, EqlEncryptedBody, EqlEncryptedIndexes}, error::{EncryptError, Error}, log::ENCRYPT, postgresql::Column, @@ -13,10 +14,9 @@ use cipherstash_client::{ config::EnvSource, credentials::{auto_refresh::AutoRefresh, ServiceCredentials}, encryption::{ - self, Encrypted, EncryptionError, IndexTerm, Plaintext, PlaintextTarget, - ReferencedPendingPipeline, + self, Encrypted, EncryptedEntry, EncryptedSteVecTerm, IndexTerm, Plaintext, + PlaintextTarget, ReferencedPendingPipeline, }, - zerokms::EncryptedRecord, ConsoleConfig, CtsConfig, ZeroKMSConfig, }; use cipherstash_config::ColumnConfig; @@ -57,10 +57,13 @@ impl Encrypt { let eql_version = { let client = connect::database(&config.database).await?; - let rows = client.query("SELECT cs_eql_version();", &[]).await; + 
let rows = client + .query("SELECT eql_v1.version() AS version;", &[]) + .await; + // let rows = client.query("SELECT 'WAT' AS version;", &[]).await; match rows { - Ok(rows) => rows.first().map(|row| row.get("cs_eql_version")), + Ok(rows) => rows.first().map(|row| row.get("version")), Err(err) => { warn!( msg = "Could not query EQL version from database", @@ -88,7 +91,7 @@ impl Encrypt { &self, plaintexts: Vec>, columns: &[Option], - ) -> Result>, Error> { + ) -> Result>, Error> { let mut pipeline = ReferencedPendingPipeline::new(self.cipher.clone()); for (idx, item) in plaintexts.into_iter().zip(columns.iter()).enumerate() { @@ -141,22 +144,17 @@ impl Encrypt { /// pub async fn decrypt( &self, - ciphertexts: Vec>, + ciphertexts: Vec>, ) -> Result>, Error> { // Create a mutable vector to hold the decrypted results let mut results = vec![None; ciphertexts.len()]; // Collect the index and ciphertext details for every Some(ciphertext) - let (indices, encrypted) = ciphertexts + let (indices, encrypted): (Vec<_>, Vec<_>) = ciphertexts .into_iter() .enumerate() - .filter_map(|(idx, opt)| { - opt.map(|ct| { - eql_encrypted_to_encrypted_record(ct) - .map(|encrypted_record| (idx, encrypted_record)) - }) - }) - .collect::, Vec<_>), _>>()?; + .filter_map(|(idx, eql)| Some((idx, eql?.body.ciphertext))) + .collect::<_>(); // Decrypt the ciphertexts let decrypted = self.cipher.decrypt(encrypted).await?; @@ -236,56 +234,120 @@ async fn init_cipher(config: &TandemConfig) -> Result { fn to_eql_encrypted( encrypted: Encrypted, identifier: &Identifier, -) -> Result { +) -> Result { debug!(target: ENCRYPT, msg = "Encrypted to EQL", ?identifier); match encrypted { Encrypted::Record(ciphertext, terms) => { - struct Indexes { - match_index: Option>, - ore_index: Option>, - unique_index: Option, - } - - let mut indexes = Indexes { - match_index: None, - ore_index: None, - unique_index: None, - }; + let mut match_index: Option> = None; + let mut ore_index: Option> = None; + let mut 
unique_index: Option = None; + let mut blake3_index: Option = None; + let mut ore_cclw_fixed_index: Option = None; + let mut ore_cclw_var_index: Option = None; + let mut selector: Option = None; for index_term in terms { match index_term { IndexTerm::Binary(bytes) => { - indexes.unique_index = Some(format_index_term_binary(&bytes)) + unique_index = Some(format_index_term_binary(&bytes)) } - IndexTerm::BitMap(inner) => indexes.match_index = Some(inner), - IndexTerm::OreArray(vec_of_bytes) => { - indexes.ore_index = Some(format_index_term_ore_array(&vec_of_bytes)); + IndexTerm::BitMap(inner) => match_index = Some(inner), + IndexTerm::OreArray(bytes) => { + ore_index = Some(format_index_term_ore_array(&bytes)); } IndexTerm::OreFull(bytes) => { - indexes.ore_index = Some(format_index_term_ore(&bytes)); + ore_index = Some(format_index_term_ore(&bytes)); } IndexTerm::OreLeft(bytes) => { - indexes.ore_index = Some(format_index_term_ore(&bytes)); + ore_index = Some(format_index_term_ore(&bytes)); } + IndexTerm::BinaryVec(_) => todo!(), + IndexTerm::SteVecSelector(s) => { + selector = Some(hex::encode(s.as_bytes())); + } + IndexTerm::SteVecTerm(ste_vec_term) => match ste_vec_term { + EncryptedSteVecTerm::Mac(bytes) => blake3_index = Some(hex::encode(bytes)), + EncryptedSteVecTerm::OreFixed(ore) => { + ore_cclw_fixed_index = Some(hex::encode(&ore)) + } + EncryptedSteVecTerm::OreVariable(ore) => { + ore_cclw_var_index = Some(hex::encode(&ore)) + } + }, + IndexTerm::SteQueryVec(_query) => {} // TODO: what do we do here? 
IndexTerm::Null => {} - _ => return Err(EncryptError::UnknownIndexTerm(identifier.to_owned()).into()), }; } - Ok(eql::Encrypted::Ciphertext { - ciphertext, + Ok(eql::EqlEncrypted { identifier: identifier.to_owned(), - match_index: indexes.match_index, - ore_index: indexes.ore_index, - unique_index: indexes.unique_index, version: 1, + body: EqlEncryptedBody { + ciphertext, + indexes: EqlEncryptedIndexes { + match_index, + ore_index, + unique_index, + blake3_index, + ore_cclw_fixed_index, + ore_cclw_var_index, + selector, + ste_vec_index: None, + }, + }, + }) + } + Encrypted::SteVec(ste_vec) => { + let ciphertext = ste_vec.root_ciphertext()?.clone(); + + let ste_vec_index: Vec = ste_vec + .into_iter() + .map(|EncryptedEntry(selector, term, ciphertext)| { + let indexes = match term { + EncryptedSteVecTerm::Mac(bytes) => EqlEncryptedIndexes { + selector: Some(hex::encode(selector.as_bytes())), + blake3_index: Some(hex::encode(bytes)), + ..Default::default() + }, + EncryptedSteVecTerm::OreFixed(ore) => EqlEncryptedIndexes { + selector: Some(hex::encode(selector.as_bytes())), + ore_cclw_fixed_index: Some(hex::encode(&ore)), + ..Default::default() + }, + EncryptedSteVecTerm::OreVariable(ore) => EqlEncryptedIndexes { + selector: Some(hex::encode(selector.as_bytes())), + ore_cclw_var_index: Some(hex::encode(&ore)), + ..Default::default() + }, + }; + + eql::EqlEncryptedBody { + ciphertext, + indexes, + } + }) + .collect(); + + // FIXME: I'm unsure if I've handled the root ciphertext correctly + // The way it's implemented right now is that it will be repeated one in the ste_vec_index. 
+ Ok(eql::EqlEncrypted { + identifier: identifier.to_owned(), + version: 1, + body: EqlEncryptedBody { + ciphertext: ciphertext.clone(), + indexes: EqlEncryptedIndexes { + match_index: None, + ore_index: None, + unique_index: None, + blake3_index: None, + ore_cclw_fixed_index: None, + ore_cclw_var_index: None, + selector: None, + ste_vec_index: Some(ste_vec_index), + }, + }, }) } - Encrypted::SteVec(ste_vec_index) => Ok(eql::Encrypted::SteVec { - identifier: identifier.to_owned(), - ste_vec_index, - version: 1, - }), } } @@ -314,15 +376,6 @@ fn format_index_term_ore(bytes: &Vec) -> Vec { vec![format_index_term_ore_bytea(bytes)] } -fn eql_encrypted_to_encrypted_record( - eql_encrypted: eql::Encrypted, -) -> Result { - match eql_encrypted { - eql::Encrypted::Ciphertext { ciphertext, .. } => Ok(ciphertext), - eql::Encrypted::SteVec { ste_vec_index, .. } => ste_vec_index.into_root_ciphertext(), - } -} - fn plaintext_type_name(pt: Plaintext) -> String { match pt { Plaintext::BigInt(_) => "BigInt".to_string(), diff --git a/packages/cipherstash-proxy/src/encrypt/schema/manager.rs b/packages/cipherstash-proxy/src/encrypt/schema/manager.rs index fc000ccb..b2603a98 100644 --- a/packages/cipherstash-proxy/src/encrypt/schema/manager.rs +++ b/packages/cipherstash-proxy/src/encrypt/schema/manager.rs @@ -132,19 +132,18 @@ pub async fn load_schema(config: &DatabaseConfig) -> Result { let table_name: String = table.get("table_name"); let primary_keys: Vec> = table.get("primary_keys"); let columns: Vec = table.get("columns"); - let _types: Vec> = table.get("column_types"); - let domains: Vec> = table.get("column_domains"); + let column_type_names: Vec> = table.get("column_type_names"); let mut table = Table::new(Ident::new(&table_name)); - columns.iter().zip(domains).for_each(|(col, domain)| { + columns.iter().zip(column_type_names).for_each(|(col, column_type_name)| { let is_primary_key = primary_keys.contains(&Some(col.to_string())); let ident = Ident::with_quote('"', col); - let 
column = match domain.as_deref() { - Some("cs_encrypted_v1") => { - debug!(target: SCHEMA, msg = "cs_encrypted_v1 column", table = table_name, column = col); + let column = match column_type_name.as_deref() { + Some("eql_v1_encrypted") => { + debug!(target: SCHEMA, msg = "eql_v1_encrypted column", table = table_name, column = col); Column::eql(ident) } _ => Column::native(ident), diff --git a/packages/cipherstash-proxy/src/encrypt/sql/select_config.sql b/packages/cipherstash-proxy/src/encrypt/sql/select_config.sql index 72827f37..8be0732f 100644 --- a/packages/cipherstash-proxy/src/encrypt/sql/select_config.sql +++ b/packages/cipherstash-proxy/src/encrypt/sql/select_config.sql @@ -1 +1 @@ -SELECT data FROM cs_configuration_v1 WHERE state = 'active' LIMIT 1; +SELECT data FROM public.eql_v1_configuration WHERE state = 'active' LIMIT 1; diff --git a/packages/cipherstash-proxy/src/encrypt/sql/select_table_schemas.sql b/packages/cipherstash-proxy/src/encrypt/sql/select_table_schemas.sql index ee3ba513..88743f3e 100644 --- a/packages/cipherstash-proxy/src/encrypt/sql/select_table_schemas.sql +++ b/packages/cipherstash-proxy/src/encrypt/sql/select_table_schemas.sql @@ -3,8 +3,7 @@ SELECT t.table_name, array_agg(distinct k.column_name)::text[] AS primary_keys, array_agg(c.column_name)::text[] AS columns, - array_agg(c.data_type)::text[] AS column_types, - array_agg(c.domain_name)::text[] AS column_domains + array_agg(c.udt_name)::text[] AS column_type_names FROM information_schema.tables t LEFT JOIN @@ -24,3 +23,6 @@ GROUP BY t.table_schema, t.table_name ORDER BY t.table_schema, t.table_name; + + + diff --git a/packages/cipherstash-proxy/src/eql/mod.rs b/packages/cipherstash-proxy/src/eql/mod.rs index 69266310..03ec0141 100644 --- a/packages/cipherstash-proxy/src/eql/mod.rs +++ b/packages/cipherstash-proxy/src/eql/mod.rs @@ -1,7 +1,4 @@ -use cipherstash_client::{ - encryption::SteVec, - zerokms::{encrypted_record, EncryptedRecord}, -}; +use 
cipherstash_client::zerokms::{encrypted_record, EncryptedRecord}; use serde::{Deserialize, Serialize}; use sqltk::parser::ast::Ident; @@ -16,6 +13,7 @@ pub struct Plaintext { #[serde(rename = "q")] pub for_query: Option, } + #[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct Identifier { #[serde(rename = "t")] @@ -65,54 +63,57 @@ pub enum ForQuery { } #[derive(Debug, Deserialize, Serialize)] -#[serde(tag = "k")] -pub enum Encrypted { - #[serde(rename = "ct")] - Ciphertext { - #[serde(rename = "c", with = "encrypted_record::formats::mp_base85")] - ciphertext: EncryptedRecord, - #[serde(rename = "o")] - ore_index: Option>, - #[serde(rename = "m")] - match_index: Option>, - #[serde(rename = "u")] - unique_index: Option, - #[serde(rename = "i")] - identifier: Identifier, - #[serde(rename = "v")] - version: u16, - }, - #[serde(rename = "sv")] - SteVec { - #[serde(rename = "sv")] - ste_vec_index: SteVec<16>, - #[serde(rename = "i")] - identifier: Identifier, - #[serde(rename = "v")] - version: u16, - }, +pub struct EqlEncrypted { + #[serde(rename = "i")] + pub(crate) identifier: Identifier, + #[serde(rename = "v")] + pub(crate) version: u16, + + #[serde(flatten)] + pub(crate) body: EqlEncryptedBody, } -// fn ident_de<'de, D>(deserializer: D) -> Result -// where -// D: serde::Deserializer<'de>, -// { -// let s = String::deserialize(deserializer)?; -// Ok(Ident::with_quote('"', s)) -// } - -// fn ident_se(ident: &Ident, serializer: S) -> Result -// where -// S: Serializer, -// { -// let s = ident.to_string(); -// serializer.serialize_str(&s) -// } +#[derive(Debug, Deserialize, Serialize)] +pub struct EqlEncryptedBody { + #[serde(rename = "c", with = "encrypted_record::formats::mp_base85")] + pub(crate) ciphertext: EncryptedRecord, + + #[serde(flatten)] + pub(crate) indexes: EqlEncryptedIndexes, +} + +#[derive(Debug, Deserialize, Serialize, Default)] +pub struct EqlEncryptedIndexes { + #[serde(rename = "o")] + pub(crate) ore_index: Option>, + 
#[serde(rename = "m")] + pub(crate) match_index: Option>, + #[serde(rename = "u")] + pub(crate) unique_index: Option, + + #[serde(rename = "s")] + pub(crate) selector: Option, + + #[serde(rename = "b")] + pub(crate) blake3_index: Option, + + #[serde(rename = "ocf")] + pub(crate) ore_cclw_fixed_index: Option, + #[serde(rename = "ocv")] + pub(crate) ore_cclw_var_index: Option, + + #[serde(rename = "sv")] + pub(crate) ste_vec_index: Option>, +} #[cfg(test)] mod tests { + use crate::{ + eql::{EqlEncryptedBody, EqlEncryptedIndexes}, + EqlEncrypted, + }; + use super::{Identifier, Plaintext}; - use crate::Encrypted; use cipherstash_client::zerokms::EncryptedRecord; use recipher::key::Iv; use uuid::Uuid; @@ -141,20 +142,28 @@ mod tests { pub fn ciphertext_json() { let expected = Identifier::new("table", "column"); - let ct = Encrypted::Ciphertext { + let ct = EqlEncrypted { identifier: expected.clone(), version: 1, - ciphertext: EncryptedRecord { - iv: Iv::default(), - ciphertext: vec![1; 32], - tag: vec![1; 16], - descriptor: "ciphertext".to_string(), - dataset_id: Some(Uuid::new_v4()), + body: EqlEncryptedBody { + ciphertext: EncryptedRecord { + iv: Iv::default(), + ciphertext: vec![1; 32], + tag: vec![1; 16], + descriptor: "ciphertext".to_string(), + dataset_id: Some(Uuid::new_v4()), + }, + indexes: EqlEncryptedIndexes { + ore_index: None, + match_index: None, + unique_index: None, + blake3_index: None, + selector: None, + ore_cclw_fixed_index: None, + ore_cclw_var_index: None, + ste_vec_index: None, + }, }, - - ore_index: None, - match_index: None, - unique_index: None, }; let value = serde_json::to_value(&ct).unwrap(); @@ -163,12 +172,7 @@ mod tests { let t = &i["t"]; assert_eq!(t, "table"); - let result: Encrypted = serde_json::from_value(value).unwrap(); - - if let Encrypted::Ciphertext { identifier, .. 
} = result { - assert_eq!(expected, identifier); - } else { - panic!("Expected Encrypted::Ciphertext"); - } + let result: EqlEncrypted = serde_json::from_value(value).unwrap(); + assert_eq!(expected, result.identifier); } } diff --git a/packages/cipherstash-proxy/src/lib.rs b/packages/cipherstash-proxy/src/lib.rs index 202a07e8..2e81ca9e 100644 --- a/packages/cipherstash-proxy/src/lib.rs +++ b/packages/cipherstash-proxy/src/lib.rs @@ -15,7 +15,7 @@ pub use crate::cli::Args; pub use crate::cli::Migrate; pub use crate::config::{DatabaseConfig, ServerConfig, TandemConfig, TlsConfig}; pub use crate::encrypt::Encrypt; -pub use crate::eql::{Encrypted, ForQuery, Identifier, Plaintext}; +pub use crate::eql::{EqlEncrypted, ForQuery, Identifier, Plaintext}; pub use crate::log::init; use std::mem; diff --git a/packages/cipherstash-proxy/src/postgresql/backend.rs b/packages/cipherstash-proxy/src/postgresql/backend.rs index 76b2e74b..950f5a3d 100644 --- a/packages/cipherstash-proxy/src/postgresql/backend.rs +++ b/packages/cipherstash-proxy/src/postgresql/backend.rs @@ -6,7 +6,7 @@ use super::messages::row_description::RowDescription; use super::messages::BackendCode; use crate::connect::Sender; use crate::encrypt::Encrypt; -use crate::eql::Encrypted; +use crate::eql::EqlEncrypted; use crate::error::Error; use crate::log::{DEVELOPMENT, MAPPER, PROTOCOL}; use crate::postgresql::context::Portal; @@ -262,7 +262,7 @@ where let result_column_format_codes = portal.format_codes(result_column_count); // Each row is converted into Vec> - let ciphertexts: Vec> = rows + let ciphertexts: Vec> = rows .iter() .map(|row| row.to_ciphertext()) .flatten_ok() diff --git a/packages/cipherstash-proxy/src/postgresql/frontend.rs b/packages/cipherstash-proxy/src/postgresql/frontend.rs index 681b7f50..a3e87a3d 100644 --- a/packages/cipherstash-proxy/src/postgresql/frontend.rs +++ b/packages/cipherstash-proxy/src/postgresql/frontend.rs @@ -23,7 +23,7 @@ use crate::prometheus::{ 
STATEMENTS_ENCRYPTED_TOTAL, STATEMENTS_PASSTHROUGH_TOTAL, STATEMENTS_TOTAL, STATEMENTS_UNMAPPABLE_TOTAL, }; -use crate::Encrypted; +use crate::EqlEncrypted; use bytes::BytesMut; use cipherstash_client::encryption::Plaintext; use eql_mapper::{self, EqlMapperError, EqlValue, TableColumn, TypeCheckedStatement}; @@ -359,7 +359,7 @@ where &mut self, typed_statement: &TypeCheckedStatement<'_>, literal_columns: &Vec>, - ) -> Result>, Error> { + ) -> Result>, Error> { let literal_values = typed_statement.literal_values(); if literal_values.is_empty() { debug!(target: MAPPER, @@ -404,7 +404,7 @@ where async fn transform_statement( &mut self, typed_statement: &TypeCheckedStatement<'_>, - encrypted_literals: &Vec>, + encrypted_literals: &Vec>, ) -> Result, Error> { // Convert literals to ast Expr let mut encrypted_expressions = vec![]; @@ -704,7 +704,7 @@ where &mut self, bind: &Bind, statement: &Statement, - ) -> Result>, Error> { + ) -> Result>, Error> { let plaintexts = bind.to_plaintext(&statement.param_columns, &statement.postgres_param_types)?; diff --git a/packages/cipherstash-proxy/src/postgresql/messages/bind.rs b/packages/cipherstash-proxy/src/postgresql/messages/bind.rs index e259b3b6..a76a02f6 100644 --- a/packages/cipherstash-proxy/src/postgresql/messages/bind.rs +++ b/packages/cipherstash-proxy/src/postgresql/messages/bind.rs @@ -76,7 +76,7 @@ impl Bind { Ok(plaintexts) } - pub fn rewrite(&mut self, encrypted: Vec>) -> Result<(), Error> { + pub fn rewrite(&mut self, encrypted: Vec>) -> Result<(), Error> { for (idx, ct) in encrypted.iter().enumerate() { if let Some(ct) = ct { let json = serde_json::to_value(ct)?; diff --git a/packages/cipherstash-proxy/src/postgresql/messages/data_row.rs b/packages/cipherstash-proxy/src/postgresql/messages/data_row.rs index e12274fe..f9e4112e 100644 --- a/packages/cipherstash-proxy/src/postgresql/messages/data_row.rs +++ b/packages/cipherstash-proxy/src/postgresql/messages/data_row.rs @@ -19,7 +19,7 @@ pub struct DataColumn { } 
impl DataRow { - pub fn to_ciphertext(&self) -> Result>, Error> { + pub fn to_ciphertext(&self) -> Result>, Error> { Ok(self.columns.iter().map(|col| col.into()).collect()) } @@ -159,7 +159,7 @@ impl TryFrom for BytesMut { } } -impl From<&DataColumn> for Option { +impl From<&DataColumn> for Option { fn from(col: &DataColumn) -> Self { debug!(target: MAPPER, data_column = ?col); match col.json_bytes() { diff --git a/packages/cipherstash-proxy/src/postgresql/messages/parse.rs b/packages/cipherstash-proxy/src/postgresql/messages/parse.rs index faf0ec6b..a2c30f5c 100644 --- a/packages/cipherstash-proxy/src/postgresql/messages/parse.rs +++ b/packages/cipherstash-proxy/src/postgresql/messages/parse.rs @@ -24,11 +24,11 @@ impl Parse { } /// - /// Encrypted columns are the cs_encrypted_v1 Domain Type - /// cs_encrypted_v1 wraps JSONB + /// Encrypted columns are the eql_v1_encrypted Domain Type + /// eql_v1_encrypted wraps JSONB /// - /// Using JSONB to avoid the complexity of loading the OID of cs_encrypted_v1 - /// PostgreSQL will coerce JSONB to cs_encrypted_v1 if it passes the constaint check + /// Using JSONB to avoid the complexity of loading the OID of eql_v1_encrypted + /// PostgreSQL will coerce JSONB to eql_v1_encrypted if it passes the constaint check /// pub fn rewrite_param_types(&mut self, columns: &[Option]) { for (idx, col) in columns.iter().enumerate() { diff --git a/packages/eql-mapper/Cargo.toml b/packages/eql-mapper/Cargo.toml index febb293e..3c63989e 100644 --- a/packages/eql-mapper/Cargo.toml +++ b/packages/eql-mapper/Cargo.toml @@ -19,6 +19,7 @@ sqltk = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } +vec1 = "1.12.1" [dev-dependencies] pretty_assertions = "^1.0" diff --git a/packages/eql-mapper/src/inference/infer_type_impls/expr.rs b/packages/eql-mapper/src/inference/infer_type_impls/expr.rs index 451c0d79..001572d6 100644 --- 
a/packages/eql-mapper/src/inference/infer_type_impls/expr.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/expr.rs @@ -158,7 +158,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { | BinaryOperator::HashArrow | BinaryOperator::HashLongArrow | BinaryOperator::AtAt - | BinaryOperator::HashMinus + | BinaryOperator::HashMinus // TODO do not support for EQL | BinaryOperator::AtQuestion | BinaryOperator::Question | BinaryOperator::QuestionAnd diff --git a/packages/eql-mapper/src/inference/infer_type_impls/function.rs b/packages/eql-mapper/src/inference/infer_type_impls/function.rs index 88347df6..585d6b7e 100644 --- a/packages/eql-mapper/src/inference/infer_type_impls/function.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/function.rs @@ -1,12 +1,16 @@ use eql_mapper_macros::trace_infer; -use sqltk::parser::ast::{Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident}; +use sqltk::parser::ast::{Function, FunctionArguments}; use crate::{ - inference::{type_error::TypeError, InferType}, - unifier::Type, - SqlIdent, TypeInferencer, + get_type_signature_for_special_cased_sql_function, inference::infer_type::InferType, + CompoundIdent, FunctionSig, TypeError, TypeInferencer, }; +/// Looks up the function signature. +/// +/// If a signature is found it means that function is handled as an EQL special case and is type checked accordingly. +/// +/// If a signature is not found then all function args and its return type are unified as native. #[trace_infer] impl<'ast> InferType<'ast, Function> for TypeInferencer<'ast> { fn infer_exit(&mut self, function: &'ast Function) -> Result<(), TypeError> { @@ -17,115 +21,14 @@ impl<'ast> InferType<'ast, Function> for TypeInferencer<'ast> { } let Function { name, args, .. 
} = function; + let fn_name = CompoundIdent::from(&name.0); - let fn_name: Vec<_> = name.0.iter().map(SqlIdent).collect(); - - if fn_name == [SqlIdent(&Ident::new("min"))] || fn_name == [SqlIdent(&Ident::new("max"))] { - // 1. There MUST be one unnamed argument (it CAN come from a subquery) - // 2. The return type is the same as the argument type - - match args { - FunctionArguments::None => { - return Err(TypeError::FunctionCall(format!( - "{} should be called with 1 argument, got 0", - fn_name.last().unwrap() - ))) - } - - FunctionArguments::Subquery(query) => { - // The query must return a single column projection which has the same type as the result of the - // call to min/max. - self.unify_node_with_type( - &**query, - Type::projection(&[(self.get_node_type(function), None)]), - )?; - } - - FunctionArguments::List(args_list) => { - if args_list.args.len() == 1 { - match &args_list.args[0] { - FunctionArg::Named { .. } | FunctionArg::ExprNamed { .. } => { - return Err(TypeError::FunctionCall(format!( - "{} cannot be called with named arguments", - fn_name.last().unwrap(), - ))) - } - - FunctionArg::Unnamed(function_arg_expr) => match function_arg_expr { - FunctionArgExpr::Expr(expr) => { - self.unify_nodes(function, expr)?; - } - - FunctionArgExpr::QualifiedWildcard(_) - | FunctionArgExpr::Wildcard => { - return Err(TypeError::FunctionCall(format!( - "{} cannot be called with wildcard arguments", - fn_name.last().unwrap(), - ))) - } - }, - } - } else { - return Err(TypeError::FunctionCall(format!( - "{} should be called with 1 argument, got {}", - fn_name.last().unwrap(), - args_list.args.len() - ))); - } - } + match get_type_signature_for_special_cased_sql_function(&fn_name, args) { + Some(sig) => { + sig.instantiate(&*self).apply_constraints(self, function)?; } - } else { - // All other functions: resolve to native - // EQL values will be rejected in function calls - self.unify_node_with_type(function, Type::any_native())?; - - match args { - // Function 
called without any arguments. - // Used for functions like `CURRENT_TIMESTAMP` that do not require parentheses () - // This is not the same as a function that has zero arguments (which would be an empty arg list) - FunctionArguments::None => {} - - FunctionArguments::Subquery(query) => { - // The query must return a single column projection which has the same type as the result of the function - self.unify_node_with_type( - &**query, - Type::projection(&[(self.get_node_type(function), None)]), - )?; - } - - FunctionArguments::List(args_list) => { - self.unify_node_with_type(function, Type::any_native())?; - for arg in &args_list.args { - match arg { - FunctionArg::ExprNamed { - name, - arg, - operator: _, - } => { - self.unify_node_with_type(name, Type::any_native())?; - match arg { - FunctionArgExpr::Expr(expr) => { - self.unify_node_with_type(expr, Type::any_native())?; - } - // Aggregate functions like COUNT(table.*) - FunctionArgExpr::QualifiedWildcard(_) => {} - // Aggregate functions like COUNT(*) - FunctionArgExpr::Wildcard => {} - } - } - FunctionArg::Named { arg, .. 
} | FunctionArg::Unnamed(arg) => match arg - { - FunctionArgExpr::Expr(expr) => { - self.unify_node_with_type(expr, Type::any_native())?; - } - // Aggregate functions like COUNT(table.*) - FunctionArgExpr::QualifiedWildcard(_) => {} - // Aggregate functions like COUNT(*) - FunctionArgExpr::Wildcard => {} - }, - } - } - } + None => { + FunctionSig::instantiate_native(function).apply_constraints(self, function)?; } } diff --git a/packages/eql-mapper/src/inference/infer_type_impls/function_arg_expr.rs b/packages/eql-mapper/src/inference/infer_type_impls/function_arg_expr.rs new file mode 100644 index 00000000..89e5d5df --- /dev/null +++ b/packages/eql-mapper/src/inference/infer_type_impls/function_arg_expr.rs @@ -0,0 +1,24 @@ +use eql_mapper_macros::trace_infer; +use sqltk::parser::ast::FunctionArgExpr; + +use crate::{inference::infer_type::InferType, TypeError, TypeInferencer}; + +#[trace_infer] +impl<'ast> InferType<'ast, FunctionArgExpr> for TypeInferencer<'ast> { + fn infer_exit(&mut self, farg_expr: &'ast FunctionArgExpr) -> Result<(), TypeError> { + let farg_expr_ty = self.get_node_type(farg_expr); + match farg_expr { + FunctionArgExpr::Expr(expr) => { + self.unify(farg_expr_ty, self.get_node_type(expr))?; + } + FunctionArgExpr::QualifiedWildcard(qualified) => { + self.unify(farg_expr_ty, self.resolve_qualified_wildcard(&qualified.0)?)?; + } + FunctionArgExpr::Wildcard => { + self.unify(farg_expr_ty, self.resolve_wildcard()?)?; + } + }; + + Ok(()) + } +} diff --git a/packages/eql-mapper/src/inference/infer_type_impls/mod.rs b/packages/eql-mapper/src/inference/infer_type_impls/mod.rs index 1f266dee..103a8cfd 100644 --- a/packages/eql-mapper/src/inference/infer_type_impls/mod.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/mod.rs @@ -1,6 +1,7 @@ // General AST nodes mod expr; mod function; +mod function_arg_expr; mod select; mod select_item; mod select_items; diff --git a/packages/eql-mapper/src/inference/mod.rs 
b/packages/eql-mapper/src/inference/mod.rs index 2e364944..b5a96a15 100644 --- a/packages/eql-mapper/src/inference/mod.rs +++ b/packages/eql-mapper/src/inference/mod.rs @@ -2,6 +2,8 @@ mod infer_type; mod infer_type_impls; mod registry; mod sequence; +mod sql_fn_macros; +mod sql_functions; mod type_error; pub mod unifier; @@ -12,7 +14,8 @@ use std::{cell::RefCell, fmt::Debug, marker::PhantomData, ops::ControlFlow, rc:: use infer_type::InferType; use sqltk::parser::ast::{ - Delete, Expr, Function, Ident, Insert, Query, Select, SelectItem, SetExpr, Statement, Values, + Delete, Expr, Function, FunctionArgExpr, Ident, Insert, Query, Select, SelectItem, SetExpr, + Statement, Values, }; use sqltk::{into_control_flow, AsNodeKey, Break, Visitable, Visitor}; @@ -20,6 +23,7 @@ use crate::{ScopeError, ScopeTracker, TableResolver}; pub(crate) use registry::*; pub(crate) use sequence::*; +pub(crate) use sql_functions::*; pub(crate) use type_error::*; /// [`Visitor`] implementation that performs type inference on AST nodes. @@ -187,6 +191,7 @@ macro_rules! dispatch_all { dispatch!($self, $method, $node, Vec); dispatch!($self, $method, $node, SelectItem); dispatch!($self, $method, $node, Function); + dispatch!($self, $method, $node, FunctionArgExpr); dispatch!($self, $method, $node, Values); dispatch!($self, $method, $node, sqltk::parser::ast::Value); }; diff --git a/packages/eql-mapper/src/inference/sql_fn_macros.rs b/packages/eql-mapper/src/inference/sql_fn_macros.rs new file mode 100644 index 00000000..4fcd9bc7 --- /dev/null +++ b/packages/eql-mapper/src/inference/sql_fn_macros.rs @@ -0,0 +1,30 @@ +#[macro_export] +macro_rules! to_kind { + (NATIVE) => { + $crate::Kind::Native + }; + ($generic:ident) => { + $crate::Kind::Generic(stringify!($generic)) + }; +} + +#[macro_export] +macro_rules! 
sql_fn_args { + (()) => { vec![] }; + + (($arg:ident)) => { vec![$crate::to_kind!($arg)] }; + + (($arg:ident $(,$rest:ident)*)) => { + vec![$crate::to_kind!($arg) $(, $crate::to_kind!($rest))*] + }; +} + +#[macro_export] +macro_rules! sql_fn { + ($name:ident $args:tt -> $return_kind:ident) => { + $crate::SqlFunction::new( + stringify!($name), + FunctionSig::new($crate::sql_fn_args!($args), $crate::to_kind!($return_kind)), + ) + }; +} diff --git a/packages/eql-mapper/src/inference/sql_functions.rs b/packages/eql-mapper/src/inference/sql_functions.rs new file mode 100644 index 00000000..7a2d42d4 --- /dev/null +++ b/packages/eql-mapper/src/inference/sql_functions.rs @@ -0,0 +1,239 @@ +use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, LazyLock}, +}; + +use derive_more::derive::Display; +use sqltk::parser::ast::{Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident}; + +use itertools::Itertools; +use vec1::{vec1, Vec1}; + +use crate::{sql_fn, unifier::Type, SqlIdent, TypeInferencer}; + +use super::TypeError; + +/// The identifier and type signature of a SQL function. +/// +/// See [`SQL_FUNCTION_SIGNATURES`]. +#[derive(Debug)] +pub(crate) struct SqlFunction(CompoundIdent, FunctionSig); + +/// A representation of the type of an argument or return type in a SQL function. +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub(crate) enum Kind { + /// A type that mjust be a native type + Native, + + /// A type that can be a native or EQL type. The `str` is the generic variable name. + Generic(&'static str), +} + +/// The type signature of a SQL functon (excluding its name). +#[derive(Debug, Clone)] +pub(crate) struct FunctionSig { + args: Vec, + return_type: Kind, + generics: HashSet<&'static str>, +} + +/// A function signature but filled in with fresh type variables that correspond with the [`Kind`] or each argument and +/// return type. 
+#[derive(Debug, Clone)] +pub(crate) struct InstantiatedSig { + args: Vec>, + return_type: Arc, +} + +impl FunctionSig { + fn new(args: Vec, return_type: Kind) -> Self { + let mut generics: HashSet<&'static str> = HashSet::new(); + + for arg in &args { + if let Kind::Generic(generic) = arg { + generics.insert(*generic); + } + } + + if let Kind::Generic(generic) = return_type { + generics.insert(generic); + } + + Self { + args, + return_type, + generics, + } + } + + /// Checks if `self` is applicable to a particular piece of SQL function invocation syntax. + pub(crate) fn is_applicable_to_args(&self, fn_args_syntax: &FunctionArguments) -> bool { + match fn_args_syntax { + FunctionArguments::None => self.args.is_empty(), + FunctionArguments::Subquery(_) => self.args.len() == 1, + FunctionArguments::List(fn_args) => self.args.len() == fn_args.args.len(), + } + } + + /// Creates an [`InstantiatedSig`] from `self`, filling in the [`Kind`]s with fresh type variables. + pub(crate) fn instantiate(&self, inferencer: &TypeInferencer<'_>) -> InstantiatedSig { + let mut generics: HashMap<&'static str, Arc> = HashMap::new(); + + for generic in self.generics.iter() { + generics.insert(generic, inferencer.fresh_tvar()); + } + + InstantiatedSig { + args: self + .args + .iter() + .map(|kind| match kind { + Kind::Native => Arc::new(Type::any_native()), + Kind::Generic(generic) => generics[generic].clone(), + }) + .collect(), + + return_type: match self.return_type { + Kind::Native => Arc::new(Type::any_native()), + Kind::Generic(generic) => generics[generic].clone(), + }, + } + } + + /// For functions that do not have special case handling we synthesise an [`InstatiatedSig`] from the SQL function + /// invocation synta where all arguments and the return types are native. 
+ pub(crate) fn instantiate_native(function: &Function) -> InstantiatedSig { + let arg_count = match &function.args { + FunctionArguments::None => 0, + FunctionArguments::Subquery(_) => 1, + FunctionArguments::List(args) => args.args.len(), + }; + + let args: Vec> = (0..arg_count) + .map(|_| Arc::new(Type::any_native())) + .collect(); + + InstantiatedSig { + args, + return_type: Arc::new(Type::any_native()), + } + } +} + +impl InstantiatedSig { + /// Applies the type constraints of the function to to the AST. + pub(crate) fn apply_constraints<'ast>( + &self, + inferencer: &mut TypeInferencer<'ast>, + function: &'ast Function, + ) -> Result<(), TypeError> { + let fn_name = CompoundIdent::from(&function.name.0); + + inferencer.unify_node_with_type(function, self.return_type.clone())?; + + match &function.args { + FunctionArguments::None => { + if self.args.is_empty() { + Ok(()) + } else { + Err(TypeError::Conflict(format!( + "expected {} args to function {}; got 0", + self.args.len(), + fn_name + ))) + } + } + + FunctionArguments::Subquery(query) => { + if self.args.len() == 1 { + inferencer.unify_node_with_type(&**query, self.args[0].clone())?; + Ok(()) + } else { + Err(TypeError::Conflict(format!( + "expected {} args to function {}; got 0", + self.args.len(), + fn_name + ))) + } + } + + FunctionArguments::List(args) => { + for (sig_arg, fn_arg) in self.args.iter().zip(args.args.iter()) { + let farg_expr = get_function_arg_expr(fn_arg); + inferencer.unify_node_with_type(farg_expr, sig_arg.clone())?; + } + + Ok(()) + } + } + } +} + +fn get_function_arg_expr(fn_arg: &FunctionArg) -> &FunctionArgExpr { + match fn_arg { + FunctionArg::Named { arg, .. } => arg, + FunctionArg::ExprNamed { arg, .. 
} => arg, + FunctionArg::Unnamed(arg) => arg, + } +} + +impl SqlFunction { + fn new(ident: &str, sig: FunctionSig) -> Self { + Self(CompoundIdent::from(ident), sig) + } +} + +#[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Clone, Display)] +#[display("{}", _0.iter().map(SqlIdent::to_string).collect::>().join("."))] +pub(crate) struct CompoundIdent(Vec1>); + +impl From<&str> for CompoundIdent { + fn from(value: &str) -> Self { + CompoundIdent(vec1![SqlIdent(Ident::new(value))]) + } +} + +impl From<&Vec> for CompoundIdent { + fn from(value: &Vec) -> Self { + let mut idents = Vec1::>::new(SqlIdent(value[0].clone())); + idents.extend(value[1..].iter().cloned().map(SqlIdent)); + CompoundIdent(idents) + } +} + +/// SQL functions that are handled with special case type checking rules. +static SQL_FUNCTION_SIGNATURES: LazyLock>> = + LazyLock::new(|| { + // Notation: a single uppercase letter denotes an unknown type. Matching letters in a signature will be assigned + // *the same type variable* and thus must resolve to the same type. (🙏 Haskell) + // + // Eventually we should type check EQL types against their configured indexes instead of leaving that to the EQL + // extension in the database. 
I can imagine supporting type bounds in signatures here, such as: `T: Eq` + let sql_fns = vec![ + sql_fn!(count(T) -> NATIVE), + sql_fn!(min(T) -> T), + sql_fn!(max(T) -> T), + sql_fn!(jsonb_path_query(T, T) -> T), + sql_fn!(jsonb_path_query_first(T, T) -> T), + sql_fn!(jsonb_path_exists(T, T) -> T), + ]; + + let mut sql_fns_by_name: HashMap> = HashMap::new(); + + for (key, chunk) in &sql_fns.into_iter().chunk_by(|sql_fn| sql_fn.0.clone()) { + sql_fns_by_name.insert( + key.clone(), + chunk.into_iter().map(|sql_fn| sql_fn.1).collect(), + ); + } + + sql_fns_by_name + }); + +pub(crate) fn get_type_signature_for_special_cased_sql_function( + fn_name: &CompoundIdent, + args: &FunctionArguments, +) -> Option<&'static FunctionSig> { + let sigs = SQL_FUNCTION_SIGNATURES.get(fn_name)?; + sigs.iter().find(|sig| sig.is_applicable_to_args(args)) +} diff --git a/packages/eql-mapper/src/lib.rs b/packages/eql-mapper/src/lib.rs index 3bd45186..d23d0719 100644 --- a/packages/eql-mapper/src/lib.rs +++ b/packages/eql-mapper/src/lib.rs @@ -33,6 +33,7 @@ mod test { use super::type_check; use crate::col; use crate::projection; + use crate::test_helpers; use crate::Param; use crate::Schema; use crate::TableResolver; @@ -41,9 +42,11 @@ mod test { Value, }; use pretty_assertions::assert_eq; + use sqltk::parser::ast::Ident; use sqltk::parser::ast::Statement; use sqltk::parser::ast::{self as ast}; use sqltk::AsNodeKey; + use sqltk::NodeKey; use std::collections::HashMap; use std::sync::Arc; use tracing::error; @@ -1017,27 +1020,63 @@ mod test { )] ); - let transformed_statement = match typed.transform(HashMap::from_iter([( + match typed.transform(HashMap::from_iter([( typed.literals[0].1.as_node_key(), ast::Value::SingleQuotedString("ENCRYPTED".into()), )])) { - Ok(transformed_statement) => transformed_statement, + Ok(transformed_statement) => assert_eq!( + transformed_statement.to_string(), + "SELECT * FROM employees WHERE salary > ROW('ENCRYPTED'::JSONB)" + ), Err(err) => panic!("statement 
transformation failed: {}", err), }; + } + + #[test] + fn insert_with_literal_subsitution() { + // init_tracing(); + + let schema = resolver(schema! { + tables: { + employees: { + id, + salary (EQL), + } + } + }); + + let statement = parse( + r#" + insert into employees (salary) values (20000) + "#, + ); - // This type checks the transformed statement so we can get hold of the encrypted literal. - let typed = match type_check(schema, &transformed_statement) { + let typed = match type_check(schema.clone(), &statement) { Ok(typed) => typed, Err(err) => panic!("type check failed: {:#?}", err), }; - assert!(typed.literals.contains(&( - EqlValue(TableColumn { - table: id("employees"), - column: id("salary") - }), - &ast::Value::SingleQuotedString("ENCRYPTED".into()), - ))); + assert_eq!( + typed.literals, + vec![( + EqlValue(TableColumn { + table: id("employees"), + column: id("salary") + }), + &ast::Value::Number(20000.into(), false) + )] + ); + + match typed.transform(HashMap::from_iter([( + typed.literals[0].1.as_node_key(), + ast::Value::SingleQuotedString("ENCRYPTED".into()), + )])) { + Ok(transformed_statement) => assert_eq!( + transformed_statement.to_string(), + "INSERT INTO employees (salary) VALUES (ROW('ENCRYPTED'::JSONB))" + ), + Err(err) => panic!("statement transformation failed: {}", err), + }; } #[test] @@ -1336,14 +1375,173 @@ mod test { }); let statement = - parse("SELECT MIN(salary), MAX(salary), department FROM employees GROUP BY department"); + parse("SELECT min(salary), max(salary), department FROM employees GROUP BY department"); match type_check(schema, &statement) { Ok(typed) => { match typed.transform(HashMap::new()) { Ok(statement) => assert_eq!( statement.to_string(), - "SELECT CS_MIN_V1(salary) AS MIN, CS_MAX_V1(salary) AS MAX, department FROM employees GROUP BY department".to_string() + "SELECT CS_MIN_V1(salary) AS min, CS_MAX_V1(salary) AS max, department FROM employees GROUP BY department".to_string() + ), + Err(err) => 
panic!("transformation failed: {err}"), + } + } + Err(err) => panic!("type check failed: {err}"), + } + } + + #[test] + fn jsonb_operator_arrow() { + test_jsonb_operator("->"); + } + + #[test] + fn jsonb_operator_long_arrow() { + test_jsonb_operator("->>"); + } + + #[test] + fn jsonb_operator_hash_arrow() { + test_jsonb_operator("#>"); + } + + #[test] + fn jsonb_operator_hash_long_arrow() { + test_jsonb_operator("#>>"); + } + + #[test] + fn jsonb_operator_hash_at_at() { + test_jsonb_operator("@@"); + } + + #[test] + fn jsonb_operator_at_question() { + test_jsonb_operator("@?"); + } + + #[test] + fn jsonb_operator_question() { + test_jsonb_operator("?"); + } + + #[test] + fn jsonb_operator_question_and() { + test_jsonb_operator("?&"); + } + + #[test] + fn jsonb_operator_question_pipe() { + test_jsonb_operator("?|"); + } + + #[test] + fn jsonb_operator_at_arrow() { + test_jsonb_operator("@>"); + } + + #[test] + fn jsonb_operator_arrow_at() { + test_jsonb_operator("<@"); + } + + #[test] + fn jsonb_function_jsonb_path_query() { + test_jsonb_function( + "jsonb_path_query", + vec![ + ast::Expr::Identifier(Ident::new("notes")), + ast::Expr::Value(ast::Value::SingleQuotedString("$.medications".to_owned())), + ], + ); + } + + // TODO: do we need to check that the RHS of JSON operators MUST be a Value node + // and not an arbitrary expression? + + fn test_jsonb_function(fn_name: &str, args: Vec) { + let schema = resolver(schema! 
{ + tables: { + patients: { + id (PK), + notes (EQL), + } + } + }); + + let args_in = args + .iter() + .map(|expr| expr.to_string()) + .collect::>() + .join(", "); + + let statement = parse(&format!( + "SELECT id, {}({}) AS meds FROM patients", + fn_name, args_in + )); + + let args_encrypted = args + .iter() + .map(|expr| match expr { + ast::Expr::Identifier(ident) => ident.to_string(), + ast::Expr::Value(ast::Value::SingleQuotedString(s)) => { + format!("ROW(''::JSONB)", s) + } + _ => panic!("unsupported expr type in test util"), + }) + .collect::>() + .join(", "); + + let mut encrypted_literals: HashMap, ast::Value> = HashMap::new(); + + for arg in args.iter() { + if let ast::Expr::Value(value) = arg { + encrypted_literals.extend(test_helpers::dummy_encrypted_json_selector( + &statement, + value.clone(), + )); + } + } + + match type_check(schema, &statement) { + Ok(typed) => match typed.transform(encrypted_literals) { + Ok(statement) => { + assert_eq!( + statement.to_string(), + format!( + "SELECT id, {}({}) AS meds FROM patients", + fn_name, args_encrypted + ) + ) + } + Err(err) => panic!("transformation failed: {err}"), + }, + Err(err) => panic!("type check failed: {err}"), + } + } + + fn test_jsonb_operator(op: &str) { + let schema = resolver(schema! 
{ + tables: { + patients: { + id (PK), + notes (EQL), + } + } + }); + + let statement = parse(&format!( + "SELECT id, notes {} 'medications' AS meds FROM patients", + op + )); + + match type_check(schema, &statement) { + Ok(typed) => { + match typed.transform(test_helpers::dummy_encrypted_json_selector(&statement, ast::Value::SingleQuotedString("medications".to_owned()))) { + Ok(statement) => assert_eq!( + statement.to_string(), + format!("SELECT id, notes {} ROW(''::JSONB) AS meds FROM patients", op) ), Err(err) => panic!("transformation failed: {err}"), } diff --git a/packages/eql-mapper/src/model/sql_ident.rs b/packages/eql-mapper/src/model/sql_ident.rs index acaa91da..aa5a0a3b 100644 --- a/packages/eql-mapper/src/model/sql_ident.rs +++ b/packages/eql-mapper/src/model/sql_ident.rs @@ -102,14 +102,28 @@ impl SqlIdent { } } -// This manual Hash implementation is required to prevent a clippy error: -// "error: you are deriving `Hash` but have implemented `PartialEq` explicitly" -impl Hash for SqlIdent -where - T: Hash, -{ +// This Hash implementation (and the following) one is required in order to be consistent with PartialEq. 
+impl Hash for SqlIdent<&Ident> { + fn hash(&self, state: &mut H) { + match self.0.quote_style { + Some(ch) => { + state.write_u8(1); + state.write_u32(ch as u32); + state.write(self.0.value.as_bytes()); + } + None => { + state.write_u8(0); + for ch in self.0.value.chars().flat_map(|ch| ch.to_lowercase()) { + state.write_u32(ch as u32); + } + } + } + } +} + +impl Hash for SqlIdent { fn hash(&self, state: &mut H) { - self.0.hash(state) + SqlIdent(&self.0).hash(state) } } diff --git a/packages/eql-mapper/src/model/type_system.rs b/packages/eql-mapper/src/model/type_system.rs index 7ea05ce9..a4ef7fe4 100644 --- a/packages/eql-mapper/src/model/type_system.rs +++ b/packages/eql-mapper/src/model/type_system.rs @@ -61,6 +61,13 @@ impl Projection { Projection::WithColumns(columns) } } + + pub fn type_at_col_index(&self, index: usize) -> Option<&Value> { + match self { + Projection::WithColumns(cols) => cols.get(index).map(|col| &col.ty), + Projection::Empty => None, + } + } } /// A column from a projection which has a type and an optional alias. 
diff --git a/packages/eql-mapper/src/test_helpers.rs b/packages/eql-mapper/src/test_helpers.rs index cf36ba46..1d9285b2 100644 --- a/packages/eql-mapper/src/test_helpers.rs +++ b/packages/eql-mapper/src/test_helpers.rs @@ -1,9 +1,12 @@ -use std::fmt::Debug; +use std::{collections::HashMap, convert::Infallible, fmt::Debug, ops::ControlFlow}; -use sqltk::parser::{ - ast::{self as ast, Statement}, - dialect::PostgreSqlDialect, - parser::Parser, +use sqltk::{ + parser::{ + ast::{self as ast, Statement, Value}, + dialect::PostgreSqlDialect, + parser::Parser, + }, + AsNodeKey, Break, NodeKey, Visitable, Visitor, }; use tracing_subscriber::fmt::format; use tracing_subscriber::fmt::format::FmtSpan; @@ -27,7 +30,7 @@ pub(crate) fn init_tracing() { }); } -pub(crate) fn parse(statement: &'static str) -> Statement { +pub(crate) fn parse(statement: &str) -> Statement { Parser::parse_sql(&PostgreSqlDialect {}, statement).unwrap()[0].clone() } @@ -35,6 +38,61 @@ pub(crate) fn id(ident: &str) -> ast::Ident { ast::Ident::from(ident) } +pub(crate) fn get_node_key_of_json_selector<'ast>( + statement: &'ast Statement, + selector: &Value, +) -> NodeKey<'ast> { + find_nodekey_for_value_node(statement, selector.clone()) + .expect("could not find selector Value node") +} + +pub(crate) fn dummy_encrypted_json_selector( + statement: &Statement, + selector: Value, +) -> HashMap, ast::Value> { + if let Value::SingleQuotedString(s) = &selector { + HashMap::from_iter(vec![( + get_node_key_of_json_selector(statement, &selector), + ast::Value::SingleQuotedString(format!("", s)), + )]) + } else { + panic!("dummy_encrypted_json_selector only works on Value::SingleQuotedString") + } +} + +/// Utility for finding the [`NodeKey`] of a [`Value`] node in `statement` by providing a `matching` equal node to search for. 
+pub(crate) fn find_nodekey_for_value_node( + statement: &Statement, + matching: ast::Value, +) -> Option> { + struct FindNode<'ast> { + needle: ast::Value, + found: Option>, + } + + impl<'a> Visitor<'a> for FindNode<'a> { + type Error = Infallible; + + fn enter(&mut self, node: &'a N) -> ControlFlow> { + if let Some(haystack) = node.downcast_ref::() { + if haystack == &self.needle { + self.found = Some(haystack.as_node_key()); + return ControlFlow::Break(Break::Finished); + } + } + ControlFlow::Continue(()) + } + } + + let mut visitor = FindNode { + needle: matching, + found: None, + }; + + statement.accept(&mut visitor); + + visitor.found +} #[macro_export] macro_rules! col { ((NATIVE)) => { diff --git a/packages/eql-mapper/src/transformation_rules/replace_plaintext_eql_literals.rs b/packages/eql-mapper/src/transformation_rules/replace_plaintext_eql_literals.rs index 687649af..a153b115 100644 --- a/packages/eql-mapper/src/transformation_rules/replace_plaintext_eql_literals.rs +++ b/packages/eql-mapper/src/transformation_rules/replace_plaintext_eql_literals.rs @@ -1,6 +1,9 @@ use std::{any::type_name, collections::HashMap}; -use sqltk::parser::ast::Value; +use sqltk::parser::ast::{ + CastKind, DataType, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList, + FunctionArguments, Ident, ObjectName, Value, +}; use sqltk::{NodeKey, NodePath, Visitable}; use crate::EqlMapperError; @@ -25,10 +28,10 @@ impl<'ast> TransformationRule<'ast> for ReplacePlaintextEqlLiterals<'ast> { target_node: &mut N, ) -> Result { if self.would_edit(node_path, target_node) { - if let Some((value,)) = node_path.last_1_as::() { + if let Some((Expr::Value(value),)) = node_path.last_1_as::() { if let Some(replacement) = self.encrypted_literals.remove(&NodeKey::new(value)) { - let target_node = target_node.downcast_mut::().unwrap(); - *target_node = replacement; + let target_node = target_node.downcast_mut::().unwrap(); + *target_node = make_row_expression(replacement); return 
Ok(true); } } @@ -38,7 +41,7 @@ impl<'ast> TransformationRule<'ast> for ReplacePlaintextEqlLiterals<'ast> { } fn would_edit(&mut self, node_path: &NodePath<'ast>, _target_node: &N) -> bool { - if let Some((value,)) = node_path.last_1_as::() { + if let Some((Expr::Value(value),)) = node_path.last_1_as::() { return self.encrypted_literals.contains_key(&NodeKey::new(value)); } false @@ -55,3 +58,25 @@ impl<'ast> TransformationRule<'ast> for ReplacePlaintextEqlLiterals<'ast> { } } } + +fn make_row_expression(replacement: Value) -> Expr { + Expr::Function(Function { + name: ObjectName(vec![Ident::new("ROW")]), + uses_odbc_syntax: false, + parameters: FunctionArguments::None, + args: FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + clauses: vec![], + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Cast { + kind: CastKind::DoubleColon, + expr: Box::new(Expr::Value(replacement)), + data_type: DataType::JSONB, + format: None, + }))], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + }) +} diff --git a/packages/eql-mapper/src/type_checked_statement.rs b/packages/eql-mapper/src/type_checked_statement.rs index bb66aef8..43012b7d 100644 --- a/packages/eql-mapper/src/type_checked_statement.rs +++ b/packages/eql-mapper/src/type_checked_statement.rs @@ -113,11 +113,7 @@ impl<'ast> TypeCheckedStatement<'ast> { } for (key, _) in encrypted_literals.iter() { - if !self - .literals - .iter() - .any(|(_, node)| &node.as_node_key() == key) - { + if !self.literal_exists_for_node_key(*key) { return Err(EqlMapperError::Transform(String::from( "encrypted literals refers to a literal node which is not present in the SQL statement" ))); @@ -126,6 +122,12 @@ impl<'ast> TypeCheckedStatement<'ast> { Ok(()) } + fn literal_exists_for_node_key(&self, key: NodeKey<'ast>) -> bool { + self.literals + .iter() + .any(|(_, node)| node.as_node_key() == key) + } + fn count_not_null_literals(&self) -> usize { self.literals .iter() 
diff --git a/proxy.Dockerfile b/proxy.Dockerfile index 27a9d431..02cea0e5 100644 --- a/proxy.Dockerfile +++ b/proxy.Dockerfile @@ -10,7 +10,7 @@ COPY cipherstash-proxy /usr/local/bin/cipherstash-proxy COPY docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh # Copy EQL install scripts -COPY cipherstash-eql.sql /opt/cipherstash-eql.sql +COPY cipherstash-encrypt.sql /opt/cipherstash-eql.sql # Copy example schema COPY docs/getting-started/schema-example.sql /opt/schema-example.sql diff --git a/tests/benchmark/sql/benchmark-schema.sql b/tests/benchmark/sql/benchmark-schema.sql index fecdf7bf..e975f2e8 100644 --- a/tests/benchmark/sql/benchmark-schema.sql +++ b/tests/benchmark/sql/benchmark-schema.sql @@ -1,4 +1,4 @@ -TRUNCATE TABLE cs_configuration_v1; +TRUNCATE TABLE public.eql_v1_configuration; DROP TABLE IF EXISTS benchmark_plaintext; CREATE TABLE benchmark_plaintext ( diff --git a/tests/sql/schema-uninstall.sql b/tests/sql/schema-uninstall.sql index 3c34ba76..ae6630cf 100644 --- a/tests/sql/schema-uninstall.sql +++ b/tests/sql/schema-uninstall.sql @@ -1,4 +1,4 @@ -DROP TABLE IF EXISTS cs_configuration_v1; +DROP TABLE IF EXISTS public.eql_v1_configuration; -- Regular old table DROP TABLE IF EXISTS plaintext; diff --git a/tests/sql/schema.sql b/tests/sql/schema.sql index c1398811..57f5e29f 100644 --- a/tests/sql/schema.sql +++ b/tests/sql/schema.sql @@ -1,4 +1,4 @@ -TRUNCATE TABLE cs_configuration_v1; +TRUNCATE TABLE public.eql_v1_configuration; -- Regular old table DROP TABLE IF EXISTS plaintext; @@ -13,95 +13,95 @@ DROP TABLE IF EXISTS encrypted; CREATE TABLE encrypted ( id bigint, plaintext text, - encrypted_text cs_encrypted_v1, - encrypted_bool cs_encrypted_v1, - encrypted_int2 cs_encrypted_v1, - encrypted_int4 cs_encrypted_v1, - encrypted_int8 cs_encrypted_v1, - encrypted_float8 cs_encrypted_v1, - encrypted_date cs_encrypted_v1, - encrypted_jsonb cs_encrypted_v1, + encrypted_text eql_v1_encrypted, + encrypted_bool eql_v1_encrypted, + encrypted_int2 
eql_v1_encrypted, + encrypted_int4 eql_v1_encrypted, + encrypted_int8 eql_v1_encrypted, + encrypted_float8 eql_v1_encrypted, + encrypted_date eql_v1_encrypted, + encrypted_jsonb eql_v1_encrypted, PRIMARY KEY(id) ); DROP TABLE IF EXISTS unconfigured; CREATE TABLE unconfigured ( id bigint, - encrypted_unconfigured cs_encrypted_v1, + encrypted_unconfigured eql_v1_encrypted, PRIMARY KEY(id) ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_text', 'unique', 'text' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_text', 'match', 'text' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_text', 'ore', 'text' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_bool', 'unique', 'boolean' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_bool', 'ore', 'boolean' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_int2', 'unique', 'small_int' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_int2', 'ore', 'small_int' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_int4', 'unique', 'int' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_int4', 'ore', 'int' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_int8', 'unique', 'big_int' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_int8', 'ore', @@ -109,35 +109,35 @@ SELECT cs_add_index_v1( ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_float8', 'unique', 'double' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_float8', 'ore', 'double' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_date', 'unique', 'date' ); -SELECT cs_add_index_v1( +SELECT eql_v1.add_index( 'encrypted', 'encrypted_date', 'ore', 'date' ); -SELECT cs_add_index_v1( +SELECT 
eql_v1.add_index( 'encrypted', 'encrypted_jsonb', 'ste_vec', @@ -145,5 +145,5 @@ SELECT cs_add_index_v1( '{"prefix": "encrypted/encrypted_jsonb"}' ); -SELECT cs_encrypt_v1(); -SELECT cs_activate_v1(); +SELECT eql_v1.encrypt(); +SELECT eql_v1.activate(); diff --git a/tests/tasks/test/integration/psql-passthrough.sh b/tests/tasks/test/integration/psql-passthrough.sh index e07fd77d..c9319539 100755 --- a/tests/tasks/test/integration/psql-passthrough.sh +++ b/tests/tasks/test/integration/psql-passthrough.sh @@ -17,10 +17,10 @@ EOF # Confirm that there is indeed no config set +e -OUTPUT="$(docker exec -i postgres${CONTAINER_SUFFIX} psql 'postgresql://cipherstash:password@proxy:6432/cipherstash?sslmode=disable' --command 'SELECT * FROM cs_configuration_v1' 2>&1)" +OUTPUT="$(docker exec -i postgres${CONTAINER_SUFFIX} psql 'postgresql://cipherstash:password@proxy:6432/cipherstash?sslmode=disable' --command 'SELECT * FROM eql_v1_configuration' 2>&1)" retval=$? -if echo ${OUTPUT} | grep -v 'relation "cs_configuration_v1" does not exist'; then - echo "error: did not see string in output: \"relation "cs_configuration_v1" does not exist\"" +if echo ${OUTPUT} | grep -v 'relation "eql_v1_configuration" does not exist'; then + echo "error: did not see string in output: \"relation "eql_v1_configuration" does not exist\"" exit 1 fi