From 542c921dc0c0410c99aca06daad08795906222da Mon Sep 17 00:00:00 2001 From: Toby Hede Date: Thu, 24 Apr 2025 11:52:42 +1000 Subject: [PATCH 1/7] chore: generate.rs module for creating EQL test data --- .../cipherstash-proxy-integration/Cargo.toml | 10 +- .../src/generate.rs | 265 ++++++++++++++++++ .../cipherstash-proxy-integration/src/lib.rs | 1 + packages/cipherstash-proxy/Cargo.toml | 8 +- 4 files changed, 281 insertions(+), 3 deletions(-) create mode 100644 packages/cipherstash-proxy-integration/src/generate.rs diff --git a/packages/cipherstash-proxy-integration/Cargo.toml b/packages/cipherstash-proxy-integration/Cargo.toml index 546e1ff8..974f72ef 100644 --- a/packages/cipherstash-proxy-integration/Cargo.toml +++ b/packages/cipherstash-proxy-integration/Cargo.toml @@ -24,5 +24,13 @@ tracing-subscriber = { workspace = true } webpki-roots = "0.26.7" [dev-dependencies] +# cipherstash-client = { version = "0.18.0-pre.1", features = ["tokio"] } +cipherstash-client = { path = "/Users/tobyhede/src/cipherstash-suite/packages/cipherstash-client", features = [ + "tokio", +] } +# cipherstash-config = "0.2.3" +cipherstash-config = { path = "/Users/tobyhede/src/cipherstash-suite/packages/cipherstash-config" } clap = "4.5.32" -fake = { version = "4", features = ["derive"] } +fake = { version = "4", features = ["chrono", "derive"] } +hex = "0.4.3" +uuid = { version = "1.11.0", features = ["serde", "v4"] } diff --git a/packages/cipherstash-proxy-integration/src/generate.rs b/packages/cipherstash-proxy-integration/src/generate.rs new file mode 100644 index 00000000..146b539f --- /dev/null +++ b/packages/cipherstash-proxy-integration/src/generate.rs @@ -0,0 +1,265 @@ +#[cfg(test)] +mod tests { + use crate::common::{clear, connect_with_tls, id, trace, PROXY}; + use cipherstash_client::config::EnvSource; + use cipherstash_client::credentials::auto_refresh::AutoRefresh; + use cipherstash_client::ejsonpath::Selector; + use cipherstash_client::encryption::{ + Encrypted, 
EncryptedEntry, EncryptedSteVecTerm, JsonIndexer, JsonIndexerOptions, OreTerm, + Plaintext, PlaintextTarget, QueryBuilder, ReferencedPendingPipeline, + }; + use cipherstash_client::{ + encryption::{ScopedCipher, SteVec}, + zerokms::{encrypted_record, EncryptedRecord}, + }; + use cipherstash_client::{ConsoleConfig, CtsConfig, ZeroKMSConfig}; + use cipherstash_config::column::{Index, IndexType}; + use cipherstash_config::{ColumnConfig, ColumnMode, ColumnType}; + use cipherstash_proxy::Identifier; + use rustls::unbuffered::EncodeError; + use serde::{Deserialize, Serialize}; + use std::sync::Arc; + use tracing::info; + use uuid::Uuid; + + pub mod option_mp_base85 { + use cipherstash_client::zerokms::encrypted_record::formats::mp_base85; + use cipherstash_client::zerokms::EncryptedRecord; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize( + value: &Option, + serializer: S, + ) -> Result + where + S: Serializer, + { + match value { + Some(record) => mp_base85::serialize(record, serializer), + None => serializer.serialize_none(), + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let result = Option::::deserialize(deserializer)?; + Ok(result) + } + } + + #[derive(Debug, Deserialize, Serialize)] + + pub struct EqlEncrypted { + #[serde(rename = "c", with = "option_mp_base85")] + ciphertext: Option, + #[serde(rename = "i")] + identifier: Identifier, + #[serde(rename = "v")] + version: u16, + + #[serde(rename = "o")] + ore_index: Option>, + #[serde(rename = "m")] + match_index: Option>, + #[serde(rename = "u")] + unique_index: Option, + + #[serde(rename = "s")] + selector: Option, + + #[serde(rename = "b")] + blake3_index: Option, + + #[serde(rename = "ocf")] + ore_cclw_fixed_index: Option, + #[serde(rename = "ocv")] + ore_cclw_var_index: Option, + + #[serde(rename = "sv")] + ste_vec_index: Option>, + } + + #[derive(Debug, Deserialize, Serialize)] + pub struct 
EqlSteVecEncrypted { + #[serde(rename = "c", with = "option_mp_base85")] + ciphertext: Option, + + #[serde(rename = "s")] + selector: Option, + #[serde(rename = "b")] + blake3_index: Option, + #[serde(rename = "ocf")] + ore_cclw_fixed_index: Option, + #[serde(rename = "ocv")] + ore_cclw_var_index: Option, + } + + impl EqlEncrypted { + pub fn ste_vec(ste_vec_index: Vec) -> Self { + Self { + ste_vec_index: Some(ste_vec_index), + ciphertext: None, + identifier: Identifier { + table: "blah".to_string(), + column: "vtha".to_string(), + }, + version: 1, + ore_index: None, + match_index: None, + unique_index: None, + selector: None, + ore_cclw_fixed_index: None, + ore_cclw_var_index: None, + blake3_index: None, + } + } + } + impl EqlSteVecEncrypted { + pub fn ste_vec_element(selector: String, record: EncryptedRecord) -> Self { + Self { + ciphertext: Some(record), + selector: Some(selector), + ore_cclw_fixed_index: None, + ore_cclw_var_index: None, + blake3_index: None, + } + } + } + + #[tokio::test] + async fn generate_ste_vec() { + trace(); + + // clear().await; + // let client = connect_with_tls(PROXY).await; + + let console_config = ConsoleConfig::builder().with_env().build().unwrap(); + let cts_config = CtsConfig::builder().with_env().build().unwrap(); + let zerokms_config = ZeroKMSConfig::builder() + .add_source(EnvSource::default()) + .console_config(&console_config) + .cts_config(&cts_config) + .build_with_client_key() + .unwrap(); + let zerokms_client = zerokms_config + .create_client_with_credentials(AutoRefresh::new(zerokms_config.credentials())); + + let dataset_id = Uuid::parse_str("295504329cb045c398dc464c52a287a1").unwrap(); + + let cipher = Arc::new( + ScopedCipher::init(Arc::new(zerokms_client), Some(dataset_id)) + .await + .unwrap(), + ); + + let prefix = "prefix".to_string(); + + let column_config = ColumnConfig::build("column_name".to_string()) + .casts_as(ColumnType::JsonB) + .add_index(Index::new(IndexType::SteVec { + prefix: prefix.to_owned(), + })); 
+ + // let mut value = + // serde_json::from_str::("{\"hello\": \"one\", \"n\": 10}").unwrap(); + + // let mut value = + // serde_json::from_str::("{\"hello\": \"two\", \"n\": 20}").unwrap(); + + let mut value = + serde_json::from_str::("{\"hello\": \"two\", \"n\": 30}").unwrap(); + + // let mut value = + // serde_json::from_str::("{\"hello\": \"world\", \"n\": 42}").unwrap(); + + // let mut value = + // serde_json::from_str::("{\"hello\": \"world\", \"n\": 42}").unwrap(); + + // let mut value = + // serde_json::from_str::("{\"blah\": { \"vtha\": 42 }}").unwrap(); + + let plaintext = Plaintext::JsonB(Some(value)); + + let idx = 0; + + let mut pipeline = ReferencedPendingPipeline::new(cipher.clone()); + let encryptable = PlaintextTarget::new(plaintext, column_config); + pipeline + .add_with_ref::(encryptable, idx) + .unwrap(); + + let mut encrypteds = vec![]; + + let mut result = pipeline.encrypt(None).await.unwrap(); + if let Some(Encrypted::SteVec(ste_vec)) = result.remove(idx) { + for entry in ste_vec { + let selector = hex::encode(entry.0 .0); + let term = entry.1; + let record = entry.2; + + let mut e = EqlSteVecEncrypted::ste_vec_element(selector, record); + + match term { + EncryptedSteVecTerm::Mac(items) => { + e.blake3_index = Some(hex::encode(&items)); + } + EncryptedSteVecTerm::OreFixed(o) => { + e.ore_cclw_fixed_index = Some(hex::encode(o.bytes)); + } + EncryptedSteVecTerm::OreVariable(o) => { + e.ore_cclw_var_index = Some(hex::encode(o.bytes)); + } + } + + encrypteds.push(e); + } + // info!("{:?}" = encrypteds); + } + + info!("---------------------------------------------"); + + let e = EqlEncrypted::ste_vec(encrypteds); + info!("{:?}" = ?e); + + let json = serde_json::to_value(e).unwrap(); + info!("{}", json); + + let indexer = JsonIndexer::new(JsonIndexerOptions { prefix }); + + info!("---------------------------------------------"); + + // Path + // let path: String = "$.blah.vtha".to_string(); + // let selector = Selector::parse(&path).unwrap(); + 
// let selector = indexer.generate_selector(selector, cipher.index_key()); + // let selector = hex::encode(selector.0); + // info!("{}", selector); + + // Comparison + let n = 30; + let term = OreTerm::Number(n); + + let term = indexer.generate_term(term, cipher.index_key()).unwrap(); + + match term { + EncryptedSteVecTerm::Mac(items) => todo!(), + EncryptedSteVecTerm::OreFixed(ore_cllw8_v1) => { + let term = hex::encode(ore_cllw8_v1.bytes); + info!("{n}: {term}"); + } + EncryptedSteVecTerm::OreVariable(ore_cllw8_variable_v1) => todo!(), + } + + // if let Some(ste_vec_index) = e.ste_vec_index { + // for e in ste_vec_index { + // info!("{}", e); + // if let Some(ct) = e.ciphertext { + // let decrypted = cipher.decrypt(encrypted).await?; + // info!("{}", decrypted); + // } + // } + // } + } +} diff --git a/packages/cipherstash-proxy-integration/src/lib.rs b/packages/cipherstash-proxy-integration/src/lib.rs index dcff8072..db8e84f8 100644 --- a/packages/cipherstash-proxy-integration/src/lib.rs +++ b/packages/cipherstash-proxy-integration/src/lib.rs @@ -1,6 +1,7 @@ mod common; mod empty_result; mod extended_protocol_error_messages; +mod generate; mod map_concat; mod map_literals; mod map_match_index; diff --git a/packages/cipherstash-proxy/Cargo.toml b/packages/cipherstash-proxy/Cargo.toml index d63a2513..a5c47ea9 100644 --- a/packages/cipherstash-proxy/Cargo.toml +++ b/packages/cipherstash-proxy/Cargo.toml @@ -8,8 +8,12 @@ bigdecimal = { version = "0.4.6", features = ["serde-json"] } arc-swap = "1.7.1" bytes = { version = "1.9", default-features = false } chrono = { version = "0.4.39", features = ["clock"] } -cipherstash-client = { version = "0.18.0-pre.1", features = ["tokio"] } -cipherstash-config = "0.2.3" +# cipherstash-client = { version = "0.18.0-pre.1", features = ["tokio"] } +cipherstash-client = { path = "/Users/tobyhede/src/cipherstash-suite/packages/cipherstash-client", features = [ + "tokio", +] } +# cipherstash-config = "0.2.3" +cipherstash-config = { 
path = "/Users/tobyhede/src/cipherstash-suite/packages/cipherstash-config" } clap = { version = "4.5.31", features = ["derive", "env"] } config = { version = "0.15", features = [ "async", From 07b8869f93db82cde83dec8b9fa53ceb897c4bc8 Mon Sep 17 00:00:00 2001 From: James Sadler Date: Wed, 30 Apr 2025 14:08:50 +1000 Subject: [PATCH 2/7] chore: formatting --- packages/cipherstash-proxy-integration/src/generate.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/cipherstash-proxy-integration/src/generate.rs b/packages/cipherstash-proxy-integration/src/generate.rs index 146b539f..49e4bf06 100644 --- a/packages/cipherstash-proxy-integration/src/generate.rs +++ b/packages/cipherstash-proxy-integration/src/generate.rs @@ -50,7 +50,6 @@ mod tests { } #[derive(Debug, Deserialize, Serialize)] - pub struct EqlEncrypted { #[serde(rename = "c", with = "option_mp_base85")] ciphertext: Option, From 3541b2ab292112724e245e50201f51a5b15c5962 Mon Sep 17 00:00:00 2001 From: James Sadler Date: Wed, 30 Apr 2025 15:18:34 +1000 Subject: [PATCH 3/7] feat: encrypted JSON should use the new EQL schema --- Cargo.lock | 210 ++++++++++++++++-- .../cipherstash-proxy-integration/Cargo.toml | 4 +- packages/cipherstash-proxy/Cargo.toml | 4 +- packages/cipherstash-proxy/src/encrypt/mod.rs | 154 ++++++++----- packages/cipherstash-proxy/src/eql/mod.rs | 132 +++++------ packages/cipherstash-proxy/src/lib.rs | 2 +- .../src/postgresql/backend.rs | 4 +- .../src/postgresql/frontend.rs | 8 +- .../src/postgresql/messages/bind.rs | 2 +- .../src/postgresql/messages/data_row.rs | 4 +- .../src/inference/infer_type_impls/expr.rs | 2 +- packages/eql-mapper/src/lib.rs | 82 +++++++ packages/eql-mapper/src/model/type_system.rs | 7 + packages/eql-mapper/src/test_helpers.rs | 36 ++- .../eql-mapper/src/type_checked_statement.rs | 47 +++- 15 files changed, 540 insertions(+), 158 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b2e0ee88..21164890 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -181,6 +181,7 @@ 
version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" dependencies = [ + "serde", "zeroize", ] @@ -287,6 +288,61 @@ dependencies = [ "fs_extra", ] +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -317,6 +373,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" +[[package]] +name = "base32" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "022dfe9eb35f19ebbcb51e0b40a5ab759f46ad60cadf7297e0bd085afb50e076" + [[package]] name = "base64" version = "0.22.1" @@ -589,8 +651,6 @@ dependencies = [ [[package]] name = "cipherstash-client" version = "0.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f099b1db6cf37b0ca36e9c8e0c2dade20f2035804e225f52475d44e750dd5dd5" dependencies = [ "aes-gcm-siv", "anyhow", @@ -605,6 +665,7 @@ dependencies = [ "cipherstash-config", "cipherstash-core", "cllw-ore", + "cts-common", "derive_more", "dirs", "futures", @@ -620,7 +681,7 @@ dependencies = [ "percent-encoding", "rand 0.8.5", "rand_chacha 0.3.1", - "recipher", + "recipher 0.1.3", "reqwest", "reqwest-middleware", "reqwest-retry", @@ -648,8 +709,6 @@ dependencies = [ [[package]] name = "cipherstash-config" version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30104045751da6e528e83804f4b22d0cddcb27aacce0e1c79604872ddb076bbf" dependencies = [ "serde", "thiserror 1.0.69", @@ -658,8 +717,6 @@ dependencies = [ [[package]] name = "cipherstash-core" version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd56dfac0a35146968ef6696fb822b22f70a664a8739874385876d5452844b7a" dependencies = [ "hmac", "lazy_static", @@ -694,7 +751,7 @@ dependencies = [ "postgres-protocol", "postgres-types", "rand 0.9.0", - "recipher", + "recipher 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", "regex", "ring", "rust_decimal", @@ -724,11 +781,14 @@ name = "cipherstash-proxy-integration" version = "0.1.0" dependencies = [ "chrono", + "cipherstash-client", + "cipherstash-config", "cipherstash-proxy", "clap", "fake 4.2.0", + "hex", "rand 0.9.0", - "recipher", + "recipher 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", "rustls", "serde", "serde_json", @@ -739,6 +799,7 @@ dependencies = [ "tokio-rustls", "tracing", "tracing-subscriber", + "uuid", "webpki-roots", ] @@ -796,8 +857,6 @@ checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] name = "cllw-ore" version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1b01c26e11101044f85802e31d842483ef983a890c03472d9489f6969cf865a" dependencies = [ "bit-vec", "bitvec", @@ -953,6 
+1012,25 @@ dependencies = [ "cipher 0.4.4", ] +[[package]] +name = "cts-common" +version = "0.1.0" +dependencies = [ + "arrayvec", + "axum", + "base32", + "diesel", + "fake 3.1.0", + "http", + "miette", + "rand 0.8.5", + "regex", + "serde", + "thiserror 1.0.69", + "url", + "vitaminc", +] + [[package]] name = "darling" version = "0.20.10" @@ -973,6 +1051,7 @@ dependencies = [ "ident_case", "proc-macro2", "quote", + "strsim", "syn 2.0.100", ] @@ -1075,6 +1154,39 @@ version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc55fe0d1f6c107595572ec8b107c0999bb1a2e0b75e37429a4fb0d6474a0e7d" +[[package]] +name = "diesel" +version = "2.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff3e1edb1f37b4953dd5176916347289ed43d7119cc2e6c7c3f7849ff44ea506" +dependencies = [ + "chrono", + "diesel_derives", + "uuid", +] + +[[package]] +name = "diesel_derives" +version = "2.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68d4216021b3ea446fd2047f5c8f8fe6e98af34508a254a01e4d6bc1e844f84d" +dependencies = [ + "diesel_table_macro_syntax", + "dsl_auto_type", + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "diesel_table_macro_syntax" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "209c735641a413bc68c4923a9d6ad4bcb3ca306b794edaa7eb0b3228a99ffb25" +dependencies = [ + "syn 2.0.100", +] + [[package]] name = "diff" version = "0.1.13" @@ -1123,6 +1235,20 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "dsl_auto_type" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139ae9aca7527f85f26dd76483eb38533fd84bd571065da1739656ef71c5ff5b" +dependencies = [ + "darling", + "either", + "heck", + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "dummy" version = "0.8.0" @@ -1135,6 +1261,18 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "dummy" 
+version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abcba80bdf851db5616e27ff869399468e2d339d7c6480f5887681e6bdfc2186" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "dummy" version = "0.11.0" @@ -1223,12 +1361,24 @@ dependencies = [ "uuid", ] +[[package]] +name = "fake" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aef603df4ba9adbca6a332db7da6f614f21eafefbaf8e087844e452fdec152d0" +dependencies = [ + "deunicode", + "dummy 0.9.2", + "rand 0.8.5", +] + [[package]] name = "fake" version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b591050272097cc85b2f3c1cc4817ba4560057d10fcae6f7339f1cf622da0a0f" dependencies = [ + "chrono", "deunicode", "dummy 0.11.0", "rand 0.9.0", @@ -2034,6 +2184,12 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "md-5" version = "0.10.6" @@ -2788,6 +2944,25 @@ dependencies = [ "bitflags 2.9.0", ] +[[package]] +name = "recipher" +version = "0.1.3" +dependencies = [ + "aes", + "async-trait", + "cmac", + "hex", + "hex-literal", + "opaque-debug", + "rand 0.8.5", + "rand_chacha 0.3.1", + "serde", + "serde_cbor", + "sha2", + "thiserror 1.0.69", + "zeroize", +] + [[package]] name = "recipher" version = "0.1.3" @@ -3328,6 +3503,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.8" @@ -3905,6 +4090,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] 
[[package]] @@ -4848,8 +5034,6 @@ dependencies = [ [[package]] name = "zerokms-protocol" version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01a9d0d8103cfa862b451f2c35144301df25a233f7fae041666b890a1578c3b1" dependencies = [ "async-trait", "base64", diff --git a/packages/cipherstash-proxy-integration/Cargo.toml b/packages/cipherstash-proxy-integration/Cargo.toml index 974f72ef..aa79a2e9 100644 --- a/packages/cipherstash-proxy-integration/Cargo.toml +++ b/packages/cipherstash-proxy-integration/Cargo.toml @@ -25,11 +25,11 @@ webpki-roots = "0.26.7" [dev-dependencies] # cipherstash-client = { version = "0.18.0-pre.1", features = ["tokio"] } -cipherstash-client = { path = "/Users/tobyhede/src/cipherstash-suite/packages/cipherstash-client", features = [ +cipherstash-client = { path = "../../../cipherstash-suite/packages/cipherstash-client", features = [ "tokio", ] } # cipherstash-config = "0.2.3" -cipherstash-config = { path = "/Users/tobyhede/src/cipherstash-suite/packages/cipherstash-config" } +cipherstash-config = { path = "../../../cipherstash-suite/packages/cipherstash-config" } clap = "4.5.32" fake = { version = "4", features = ["chrono", "derive"] } hex = "0.4.3" diff --git a/packages/cipherstash-proxy/Cargo.toml b/packages/cipherstash-proxy/Cargo.toml index a5c47ea9..09bc0399 100644 --- a/packages/cipherstash-proxy/Cargo.toml +++ b/packages/cipherstash-proxy/Cargo.toml @@ -9,11 +9,11 @@ arc-swap = "1.7.1" bytes = { version = "1.9", default-features = false } chrono = { version = "0.4.39", features = ["clock"] } # cipherstash-client = { version = "0.18.0-pre.1", features = ["tokio"] } -cipherstash-client = { path = "/Users/tobyhede/src/cipherstash-suite/packages/cipherstash-client", features = [ +cipherstash-client = { path = "../../../cipherstash-suite/packages/cipherstash-client", features = [ "tokio", ] } # cipherstash-config = "0.2.3" -cipherstash-config = { path = 
"/Users/tobyhede/src/cipherstash-suite/packages/cipherstash-config" } +cipherstash-config = { path = "../../../cipherstash-suite/packages/cipherstash-config" } clap = { version = "4.5.31", features = ["derive", "env"] } config = { version = "0.15", features = [ "async", diff --git a/packages/cipherstash-proxy/src/encrypt/mod.rs b/packages/cipherstash-proxy/src/encrypt/mod.rs index faac044e..30bf52c9 100644 --- a/packages/cipherstash-proxy/src/encrypt/mod.rs +++ b/packages/cipherstash-proxy/src/encrypt/mod.rs @@ -3,7 +3,8 @@ mod schema; use crate::{ config::TandemConfig, - connect, eql, + connect, + eql::{self, EqlEncryptedBody, EqlEncryptedIndexes}, error::{EncryptError, Error}, log::ENCRYPT, postgresql::Column, @@ -13,10 +14,9 @@ use cipherstash_client::{ config::EnvSource, credentials::{auto_refresh::AutoRefresh, ServiceCredentials}, encryption::{ - self, Encrypted, EncryptionError, IndexTerm, Plaintext, PlaintextTarget, - ReferencedPendingPipeline, + self, Encrypted, EncryptedEntry, EncryptedSteVecTerm, IndexTerm, Plaintext, + PlaintextTarget, ReferencedPendingPipeline, }, - zerokms::EncryptedRecord, ConsoleConfig, CtsConfig, ZeroKMSConfig, }; use cipherstash_config::ColumnConfig; @@ -88,7 +88,7 @@ impl Encrypt { &self, plaintexts: Vec>, columns: &[Option], - ) -> Result>, Error> { + ) -> Result>, Error> { let mut pipeline = ReferencedPendingPipeline::new(self.cipher.clone()); for (idx, item) in plaintexts.into_iter().zip(columns.iter()).enumerate() { @@ -141,22 +141,17 @@ impl Encrypt { /// pub async fn decrypt( &self, - ciphertexts: Vec>, + ciphertexts: Vec>, ) -> Result>, Error> { // Create a mutable vector to hold the decrypted results let mut results = vec![None; ciphertexts.len()]; // Collect the index and ciphertext details for every Some(ciphertext) - let (indices, encrypted) = ciphertexts + let (indices, encrypted): (Vec<_>, Vec<_>) = ciphertexts .into_iter() .enumerate() - .filter_map(|(idx, opt)| { - opt.map(|ct| { - 
eql_encrypted_to_encrypted_record(ct) - .map(|encrypted_record| (idx, encrypted_record)) - }) - }) - .collect::, Vec<_>), _>>()?; + .filter_map(|(idx, eql)| Some((idx, eql?.body.ciphertext))) + .collect::<_>(); // Decrypt the ciphertexts let decrypted = self.cipher.decrypt(encrypted).await?; @@ -236,56 +231,120 @@ async fn init_cipher(config: &TandemConfig) -> Result { fn to_eql_encrypted( encrypted: Encrypted, identifier: &Identifier, -) -> Result { +) -> Result { debug!(target: ENCRYPT, msg = "Encrypted to EQL", ?identifier); match encrypted { Encrypted::Record(ciphertext, terms) => { - struct Indexes { - match_index: Option>, - ore_index: Option>, - unique_index: Option, - } - - let mut indexes = Indexes { - match_index: None, - ore_index: None, - unique_index: None, - }; + let mut match_index: Option> = None; + let mut ore_index: Option> = None; + let mut unique_index: Option = None; + let mut blake3_index: Option = None; + let mut ore_cclw_fixed_index: Option = None; + let mut ore_cclw_var_index: Option = None; + let mut selector: Option = None; for index_term in terms { match index_term { IndexTerm::Binary(bytes) => { - indexes.unique_index = Some(format_index_term_binary(&bytes)) + unique_index = Some(format_index_term_binary(&bytes)) } - IndexTerm::BitMap(inner) => indexes.match_index = Some(inner), - IndexTerm::OreArray(vec_of_bytes) => { - indexes.ore_index = Some(format_index_term_ore_array(&vec_of_bytes)); + IndexTerm::BitMap(inner) => match_index = Some(inner), + IndexTerm::OreArray(bytes) => { + ore_index = Some(format_index_term_ore_array(&bytes)); } IndexTerm::OreFull(bytes) => { - indexes.ore_index = Some(format_index_term_ore(&bytes)); + ore_index = Some(format_index_term_ore(&bytes)); } IndexTerm::OreLeft(bytes) => { - indexes.ore_index = Some(format_index_term_ore(&bytes)); + ore_index = Some(format_index_term_ore(&bytes)); } + IndexTerm::BinaryVec(_) => todo!(), + IndexTerm::SteVecSelector(s) => { + selector = Some(hex::encode(s.0)); + } + 
IndexTerm::SteVecTerm(ste_vec_term) => match ste_vec_term { + EncryptedSteVecTerm::Mac(bytes) => blake3_index = Some(hex::encode(bytes)), + EncryptedSteVecTerm::OreFixed(ore) => { + ore_cclw_fixed_index = Some(hex::encode(ore.bytes)) + } + EncryptedSteVecTerm::OreVariable(ore) => { + ore_cclw_var_index = Some(hex::encode(ore.bytes)) + } + }, + IndexTerm::SteQueryVec(query) => {} // TODO: what do we do here? IndexTerm::Null => {} - _ => return Err(EncryptError::UnknownIndexTerm(identifier.to_owned()).into()), }; } - Ok(eql::Encrypted::Ciphertext { - ciphertext, + Ok(eql::EqlEncrypted { identifier: identifier.to_owned(), - match_index: indexes.match_index, - ore_index: indexes.ore_index, - unique_index: indexes.unique_index, version: 1, + body: EqlEncryptedBody { + ciphertext, + indexes: EqlEncryptedIndexes { + match_index, + ore_index, + unique_index, + blake3_index, + ore_cclw_fixed_index, + ore_cclw_var_index, + selector, + ste_vec_index: None, + }, + }, + }) + } + Encrypted::SteVec(ste_vec) => { + let ciphertext = ste_vec.root_ciphertext()?.clone(); + + let ste_vec_index: Vec = ste_vec + .into_iter() + .map(|EncryptedEntry(selector, term, ciphertext)| { + let indexes = match term { + EncryptedSteVecTerm::Mac(bytes) => EqlEncryptedIndexes { + selector: Some(hex::encode(selector.0)), + blake3_index: Some(hex::encode(bytes)), + ..Default::default() + }, + EncryptedSteVecTerm::OreFixed(ore) => EqlEncryptedIndexes { + selector: Some(hex::encode(selector.0)), + ore_cclw_fixed_index: Some(hex::encode(ore.bytes)), + ..Default::default() + }, + EncryptedSteVecTerm::OreVariable(ore) => EqlEncryptedIndexes { + selector: Some(hex::encode(selector.0)), + ore_cclw_var_index: Some(hex::encode(ore.bytes)), + ..Default::default() + }, + }; + + eql::EqlEncryptedBody { + ciphertext, + indexes, + } + }) + .collect(); + + // FIXME: I'm unsure if I've handled the root ciphertext correctly + // The way it's implemented right now is that it will be repeated one in the ste_vec_index. 
+ Ok(eql::EqlEncrypted { + identifier: identifier.to_owned(), + version: 1, + body: EqlEncryptedBody { + ciphertext: ciphertext.clone(), + indexes: EqlEncryptedIndexes { + match_index: None, + ore_index: None, + unique_index: None, + blake3_index: None, + ore_cclw_fixed_index: None, + ore_cclw_var_index: None, + selector: None, + ste_vec_index: Some(ste_vec_index), + }, + }, }) } - Encrypted::SteVec(ste_vec_index) => Ok(eql::Encrypted::SteVec { - identifier: identifier.to_owned(), - ste_vec_index, - version: 1, - }), } } @@ -314,15 +373,6 @@ fn format_index_term_ore(bytes: &Vec) -> Vec { vec![format_index_term_ore_bytea(bytes)] } -fn eql_encrypted_to_encrypted_record( - eql_encrypted: eql::Encrypted, -) -> Result { - match eql_encrypted { - eql::Encrypted::Ciphertext { ciphertext, .. } => Ok(ciphertext), - eql::Encrypted::SteVec { ste_vec_index, .. } => ste_vec_index.into_root_ciphertext(), - } -} - fn plaintext_type_name(pt: Plaintext) -> String { match pt { Plaintext::BigInt(_) => "BigInt".to_string(), diff --git a/packages/cipherstash-proxy/src/eql/mod.rs b/packages/cipherstash-proxy/src/eql/mod.rs index 69266310..03ec0141 100644 --- a/packages/cipherstash-proxy/src/eql/mod.rs +++ b/packages/cipherstash-proxy/src/eql/mod.rs @@ -1,7 +1,4 @@ -use cipherstash_client::{ - encryption::SteVec, - zerokms::{encrypted_record, EncryptedRecord}, -}; +use cipherstash_client::zerokms::{encrypted_record, EncryptedRecord}; use serde::{Deserialize, Serialize}; use sqltk::parser::ast::Ident; @@ -16,6 +13,7 @@ pub struct Plaintext { #[serde(rename = "q")] pub for_query: Option, } + #[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct Identifier { #[serde(rename = "t")] @@ -65,54 +63,57 @@ pub enum ForQuery { } #[derive(Debug, Deserialize, Serialize)] -#[serde(tag = "k")] -pub enum Encrypted { - #[serde(rename = "ct")] - Ciphertext { - #[serde(rename = "c", with = "encrypted_record::formats::mp_base85")] - ciphertext: EncryptedRecord, - #[serde(rename = 
"o")] - ore_index: Option>, - #[serde(rename = "m")] - match_index: Option>, - #[serde(rename = "u")] - unique_index: Option, - #[serde(rename = "i")] - identifier: Identifier, - #[serde(rename = "v")] - version: u16, - }, - #[serde(rename = "sv")] - SteVec { - #[serde(rename = "sv")] - ste_vec_index: SteVec<16>, - #[serde(rename = "i")] - identifier: Identifier, - #[serde(rename = "v")] - version: u16, - }, +pub struct EqlEncrypted { + #[serde(rename = "i")] + pub(crate) identifier: Identifier, + #[serde(rename = "v")] + pub(crate) version: u16, + + #[serde(flatten)] + pub(crate) body: EqlEncryptedBody, } -// fn ident_de<'de, D>(deserializer: D) -> Result -// where -// D: serde::Deserializer<'de>, -// { -// let s = String::deserialize(deserializer)?; -// Ok(Ident::with_quote('"', s)) -// } - -// fn ident_se(ident: &Ident, serializer: S) -> Result -// where -// S: Serializer, -// { -// let s = ident.to_string(); -// serializer.serialize_str(&s) -// } +#[derive(Debug, Deserialize, Serialize)] +pub struct EqlEncryptedBody { + #[serde(rename = "c", with = "encrypted_record::formats::mp_base85")] + pub(crate) ciphertext: EncryptedRecord, + + #[serde(flatten)] + pub(crate) indexes: EqlEncryptedIndexes, +} + +#[derive(Debug, Deserialize, Serialize, Default)] +pub struct EqlEncryptedIndexes { + #[serde(rename = "o")] + pub(crate) ore_index: Option>, + #[serde(rename = "m")] + pub(crate) match_index: Option>, + #[serde(rename = "u")] + pub(crate) unique_index: Option, + + #[serde(rename = "s")] + pub(crate) selector: Option, + + #[serde(rename = "b")] + pub(crate) blake3_index: Option, + + #[serde(rename = "ocf")] + pub(crate) ore_cclw_fixed_index: Option, + #[serde(rename = "ocv")] + pub(crate) ore_cclw_var_index: Option, + + #[serde(rename = "sv")] + pub(crate) ste_vec_index: Option>, +} #[cfg(test)] mod tests { + use crate::{ + eql::{EqlEncryptedBody, EqlEncryptedIndexes}, + EqlEncrypted, + }; + use super::{Identifier, Plaintext}; - use crate::Encrypted; use 
cipherstash_client::zerokms::EncryptedRecord; use recipher::key::Iv; use uuid::Uuid; @@ -141,20 +142,28 @@ mod tests { pub fn ciphertext_json() { let expected = Identifier::new("table", "column"); - let ct = Encrypted::Ciphertext { + let ct = EqlEncrypted { identifier: expected.clone(), version: 1, - ciphertext: EncryptedRecord { - iv: Iv::default(), - ciphertext: vec![1; 32], - tag: vec![1; 16], - descriptor: "ciphertext".to_string(), - dataset_id: Some(Uuid::new_v4()), + body: EqlEncryptedBody { + ciphertext: EncryptedRecord { + iv: Iv::default(), + ciphertext: vec![1; 32], + tag: vec![1; 16], + descriptor: "ciphertext".to_string(), + dataset_id: Some(Uuid::new_v4()), + }, + indexes: EqlEncryptedIndexes { + ore_index: None, + match_index: None, + unique_index: None, + blake3_index: None, + selector: None, + ore_cclw_fixed_index: None, + ore_cclw_var_index: None, + ste_vec_index: None, + }, }, - - ore_index: None, - match_index: None, - unique_index: None, }; let value = serde_json::to_value(&ct).unwrap(); @@ -163,12 +172,7 @@ mod tests { let t = &i["t"]; assert_eq!(t, "table"); - let result: Encrypted = serde_json::from_value(value).unwrap(); - - if let Encrypted::Ciphertext { identifier, .. 
} = result { - assert_eq!(expected, identifier); - } else { - panic!("Expected Encrypted::Ciphertext"); - } + let result: EqlEncrypted = serde_json::from_value(value).unwrap(); + assert_eq!(expected, result.identifier); } } diff --git a/packages/cipherstash-proxy/src/lib.rs b/packages/cipherstash-proxy/src/lib.rs index 202a07e8..2e81ca9e 100644 --- a/packages/cipherstash-proxy/src/lib.rs +++ b/packages/cipherstash-proxy/src/lib.rs @@ -15,7 +15,7 @@ pub use crate::cli::Args; pub use crate::cli::Migrate; pub use crate::config::{DatabaseConfig, ServerConfig, TandemConfig, TlsConfig}; pub use crate::encrypt::Encrypt; -pub use crate::eql::{Encrypted, ForQuery, Identifier, Plaintext}; +pub use crate::eql::{EqlEncrypted, ForQuery, Identifier, Plaintext}; pub use crate::log::init; use std::mem; diff --git a/packages/cipherstash-proxy/src/postgresql/backend.rs b/packages/cipherstash-proxy/src/postgresql/backend.rs index 76b2e74b..950f5a3d 100644 --- a/packages/cipherstash-proxy/src/postgresql/backend.rs +++ b/packages/cipherstash-proxy/src/postgresql/backend.rs @@ -6,7 +6,7 @@ use super::messages::row_description::RowDescription; use super::messages::BackendCode; use crate::connect::Sender; use crate::encrypt::Encrypt; -use crate::eql::Encrypted; +use crate::eql::EqlEncrypted; use crate::error::Error; use crate::log::{DEVELOPMENT, MAPPER, PROTOCOL}; use crate::postgresql::context::Portal; @@ -262,7 +262,7 @@ where let result_column_format_codes = portal.format_codes(result_column_count); // Each row is converted into Vec> - let ciphertexts: Vec> = rows + let ciphertexts: Vec> = rows .iter() .map(|row| row.to_ciphertext()) .flatten_ok() diff --git a/packages/cipherstash-proxy/src/postgresql/frontend.rs b/packages/cipherstash-proxy/src/postgresql/frontend.rs index 681b7f50..a3e87a3d 100644 --- a/packages/cipherstash-proxy/src/postgresql/frontend.rs +++ b/packages/cipherstash-proxy/src/postgresql/frontend.rs @@ -23,7 +23,7 @@ use crate::prometheus::{ 
STATEMENTS_ENCRYPTED_TOTAL, STATEMENTS_PASSTHROUGH_TOTAL, STATEMENTS_TOTAL, STATEMENTS_UNMAPPABLE_TOTAL, }; -use crate::Encrypted; +use crate::EqlEncrypted; use bytes::BytesMut; use cipherstash_client::encryption::Plaintext; use eql_mapper::{self, EqlMapperError, EqlValue, TableColumn, TypeCheckedStatement}; @@ -359,7 +359,7 @@ where &mut self, typed_statement: &TypeCheckedStatement<'_>, literal_columns: &Vec>, - ) -> Result>, Error> { + ) -> Result>, Error> { let literal_values = typed_statement.literal_values(); if literal_values.is_empty() { debug!(target: MAPPER, @@ -404,7 +404,7 @@ where async fn transform_statement( &mut self, typed_statement: &TypeCheckedStatement<'_>, - encrypted_literals: &Vec>, + encrypted_literals: &Vec>, ) -> Result, Error> { // Convert literals to ast Expr let mut encrypted_expressions = vec![]; @@ -704,7 +704,7 @@ where &mut self, bind: &Bind, statement: &Statement, - ) -> Result>, Error> { + ) -> Result>, Error> { let plaintexts = bind.to_plaintext(&statement.param_columns, &statement.postgres_param_types)?; diff --git a/packages/cipherstash-proxy/src/postgresql/messages/bind.rs b/packages/cipherstash-proxy/src/postgresql/messages/bind.rs index e259b3b6..a76a02f6 100644 --- a/packages/cipherstash-proxy/src/postgresql/messages/bind.rs +++ b/packages/cipherstash-proxy/src/postgresql/messages/bind.rs @@ -76,7 +76,7 @@ impl Bind { Ok(plaintexts) } - pub fn rewrite(&mut self, encrypted: Vec>) -> Result<(), Error> { + pub fn rewrite(&mut self, encrypted: Vec>) -> Result<(), Error> { for (idx, ct) in encrypted.iter().enumerate() { if let Some(ct) = ct { let json = serde_json::to_value(ct)?; diff --git a/packages/cipherstash-proxy/src/postgresql/messages/data_row.rs b/packages/cipherstash-proxy/src/postgresql/messages/data_row.rs index e12274fe..f9e4112e 100644 --- a/packages/cipherstash-proxy/src/postgresql/messages/data_row.rs +++ b/packages/cipherstash-proxy/src/postgresql/messages/data_row.rs @@ -19,7 +19,7 @@ pub struct DataColumn { } 
impl DataRow { - pub fn to_ciphertext(&self) -> Result>, Error> { + pub fn to_ciphertext(&self) -> Result>, Error> { Ok(self.columns.iter().map(|col| col.into()).collect()) } @@ -159,7 +159,7 @@ impl TryFrom for BytesMut { } } -impl From<&DataColumn> for Option { +impl From<&DataColumn> for Option { fn from(col: &DataColumn) -> Self { debug!(target: MAPPER, data_column = ?col); match col.json_bytes() { diff --git a/packages/eql-mapper/src/inference/infer_type_impls/expr.rs b/packages/eql-mapper/src/inference/infer_type_impls/expr.rs index 451c0d79..001572d6 100644 --- a/packages/eql-mapper/src/inference/infer_type_impls/expr.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/expr.rs @@ -158,7 +158,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { | BinaryOperator::HashArrow | BinaryOperator::HashLongArrow | BinaryOperator::AtAt - | BinaryOperator::HashMinus + | BinaryOperator::HashMinus // TODO do not support for EQL | BinaryOperator::AtQuestion | BinaryOperator::Question | BinaryOperator::QuestionAnd diff --git a/packages/eql-mapper/src/lib.rs b/packages/eql-mapper/src/lib.rs index 3bd45186..920d1d58 100644 --- a/packages/eql-mapper/src/lib.rs +++ b/packages/eql-mapper/src/lib.rs @@ -33,6 +33,7 @@ mod test { use super::type_check; use crate::col; use crate::projection; + use crate::test_helpers; use crate::Param; use crate::Schema; use crate::TableResolver; @@ -1351,4 +1352,85 @@ mod test { Err(err) => panic!("type check failed: {err}"), } } + + #[test] + fn jsonb_operator_arrow() { + test_jsonb_operator("->"); + } + + #[test] + fn jsonb_operator_long_arrow() { + test_jsonb_operator("->>"); + } + + #[test] + fn jsonb_operator_hash_arrow() { + test_jsonb_operator("#>"); + } + + #[test] + fn jsonb_operator_hash_long_arrow() { + test_jsonb_operator("#>>"); + } + + #[test] + fn jsonb_operator_hash_at_at() { + test_jsonb_operator("@@"); + } + + #[test] + fn jsonb_operator_at_question() { + test_jsonb_operator("@?"); + } + + #[test] + fn 
jsonb_operator_question() { + test_jsonb_operator("?"); + } + + #[test] + fn jsonb_operator_question_and() { + test_jsonb_operator("?&"); + } + + #[test] + fn jsonb_operator_question_pipe() { + test_jsonb_operator("?|"); + } + + #[test] + fn jsonb_operator_at_arrow() { + test_jsonb_operator("@>"); + } + + #[test] + fn jsonb_operator_arrow_at() { + test_jsonb_operator("<@"); + } + + fn test_jsonb_operator(op: &'static str) { + let schema = resolver(schema! { + tables: { + patients: { + id (PK), + notes (EQL), + } + } + }); + + let statement = parse(&format!("SELECT id, notes {} 'medications' AS meds FROM patients", op)); + + match type_check(schema, &statement) { + Ok(typed) => { + match typed.transform(test_helpers::dummy_encrypted_json_selector(&typed, "medications")) { + Ok(statement) => assert_eq!( + statement.to_string(), + format!("SELECT id, notes {} '' AS meds FROM patients", op) + ), + Err(err) => panic!("transformation failed: {err}"), + } + } + Err(err) => panic!("type check failed: {err}"), + } + } } diff --git a/packages/eql-mapper/src/model/type_system.rs b/packages/eql-mapper/src/model/type_system.rs index 7ea05ce9..a4ef7fe4 100644 --- a/packages/eql-mapper/src/model/type_system.rs +++ b/packages/eql-mapper/src/model/type_system.rs @@ -61,6 +61,13 @@ impl Projection { Projection::WithColumns(columns) } } + + pub fn type_at_col_index(&self, index: usize) -> Option<&Value> { + match self { + Projection::WithColumns(cols) => cols.get(index).map(|col| &col.ty), + Projection::Empty => None, + } + } } /// A column from a projection which has a type and an optional alias. 
diff --git a/packages/eql-mapper/src/test_helpers.rs b/packages/eql-mapper/src/test_helpers.rs index cf36ba46..4457bdbe 100644 --- a/packages/eql-mapper/src/test_helpers.rs +++ b/packages/eql-mapper/src/test_helpers.rs @@ -1,16 +1,19 @@ -use std::fmt::Debug; +use std::{collections::HashMap, fmt::Debug}; -use sqltk::parser::{ - ast::{self as ast, Statement}, - dialect::PostgreSqlDialect, - parser::Parser, +use sqltk::{ + parser::{ + ast::{self as ast, Statement}, + dialect::PostgreSqlDialect, + parser::Parser, + }, + NodeKey, }; use tracing_subscriber::fmt::format; use tracing_subscriber::fmt::format::FmtSpan; use std::sync::Once; -use crate::{Projection, ProjectionColumn}; +use crate::{Projection, ProjectionColumn, TypeCheckedStatement}; #[allow(unused)] pub(crate) fn init_tracing() { @@ -27,7 +30,7 @@ pub(crate) fn init_tracing() { }); } -pub(crate) fn parse(statement: &'static str) -> Statement { +pub(crate) fn parse(statement: &str) -> Statement { Parser::parse_sql(&PostgreSqlDialect {}, statement).unwrap()[0].clone() } @@ -35,6 +38,25 @@ pub(crate) fn id(ident: &str) -> ast::Ident { ast::Ident::from(ident) } +pub(crate) fn get_node_key_of_json_selector<'ast>( + typed: &TypeCheckedStatement<'ast>, + selector: &'static str, +) -> NodeKey<'ast> { + typed + .find_nodekey_for_value_node(ast::Value::SingleQuotedString(selector.into())) + .expect("could not find selector Value node") +} + +pub(crate) fn dummy_encrypted_json_selector<'ast>( + typed: &TypeCheckedStatement<'ast>, + selector: &'static str, +) -> HashMap, ast::Value> { + HashMap::from_iter(vec![( + get_node_key_of_json_selector(typed, selector), + ast::Value::SingleQuotedString(format!("", selector)), + )]) +} + #[macro_export] macro_rules! 
col { ((NATIVE)) => { diff --git a/packages/eql-mapper/src/type_checked_statement.rs b/packages/eql-mapper/src/type_checked_statement.rs index bb66aef8..81758446 100644 --- a/packages/eql-mapper/src/type_checked_statement.rs +++ b/packages/eql-mapper/src/type_checked_statement.rs @@ -1,7 +1,10 @@ +use std::any::TypeId; +use std::convert::Infallible; +use std::ops::ControlFlow; use std::{collections::HashMap, sync::Arc}; -use sqltk::parser::ast::{self, Statement}; -use sqltk::{AsNodeKey, NodeKey, Transformable}; +use sqltk::parser::ast::{self, Query, SetExpr, Statement}; +use sqltk::{AsNodeKey, Break, NodeKey, Transformable, Visitable, Visitor}; use crate::{ DryRunnable, EqlMapperError, EqlValue, FailOnPlaceholderChange, GroupByEqlCol, Param, @@ -81,6 +84,34 @@ impl<'ast> TypeCheckedStatement<'ast> { self.statement.apply_transform(&mut transformer) } + /// Utility for finding the [`NodeKey`] of a [`Value`] node in `statement` by providing a `matching` equal node to search for. + #[cfg(test)] + pub(crate) fn find_nodekey_for_value_node(&self, matching: ast::Value) -> Option> { + struct FindNode<'ast> { + needle: ast::Value, + found: Option>, + } + + impl<'a> Visitor<'a> for FindNode<'a> { + type Error = Infallible; + + fn enter(&mut self, node: &'a N) -> ControlFlow> { + if let Some(haystack) = node.downcast_ref::() { + if haystack == &self.needle { + self.found = Some(haystack.as_node_key()); + return ControlFlow::Break(Break::Finished) + } + } + ControlFlow::Continue(()) + } + } + + let mut visitor = FindNode{ needle: matching, found: None }; + self.statement.accept(&mut visitor); + + visitor.found + } + pub fn literal_values(&self) -> Vec<&sqltk::parser::ast::Value> { self.literals .iter() @@ -113,11 +144,7 @@ impl<'ast> TypeCheckedStatement<'ast> { } for (key, _) in encrypted_literals.iter() { - if !self - .literals - .iter() - .any(|(_, node)| &node.as_node_key() == key) - { + if !self.literal_exists_for_node_key(*key) { return 
Err(EqlMapperError::Transform(String::from( "encrypted literals refers to a literal node which is not present in the SQL statement" ))); @@ -126,6 +153,12 @@ impl<'ast> TypeCheckedStatement<'ast> { Ok(()) } + fn literal_exists_for_node_key(&self, key: NodeKey<'ast>) -> bool { + self.literals + .iter() + .any(|(_, node)| node.as_node_key() == key) + } + fn count_not_null_literals(&self) -> usize { self.literals .iter() From 1cdb426d2d95a5107731f6dea12efc255d6b3777 Mon Sep 17 00:00:00 2001 From: James Sadler Date: Thu, 1 May 2025 23:27:16 +1000 Subject: [PATCH 4/7] chore(mapper): get test infra in place for JSONB functions --- Cargo.lock | 7 + .../src/generate.rs | 23 +- packages/cipherstash-proxy/src/encrypt/mod.rs | 2 +- packages/eql-mapper/Cargo.toml | 1 + .../inference/infer_type_impls/function.rs | 125 ++-------- .../infer_type_impls/function_arg_expr.rs | 24 ++ .../src/inference/infer_type_impls/mod.rs | 1 + packages/eql-mapper/src/inference/mod.rs | 7 +- .../eql-mapper/src/inference/sql_fn_macros.rs | 30 +++ .../eql-mapper/src/inference/sql_functions.rs | 223 ++++++++++++++++++ packages/eql-mapper/src/lib.rs | 86 ++++++- packages/eql-mapper/src/test_helpers.rs | 64 +++-- .../eql-mapper/src/type_checked_statement.rs | 35 +-- 13 files changed, 451 insertions(+), 177 deletions(-) create mode 100644 packages/eql-mapper/src/inference/infer_type_impls/function_arg_expr.rs create mode 100644 packages/eql-mapper/src/inference/sql_fn_macros.rs create mode 100644 packages/eql-mapper/src/inference/sql_functions.rs diff --git a/Cargo.lock b/Cargo.lock index 21164890..db71334f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1310,6 +1310,7 @@ dependencies = [ "thiserror 2.0.12", "tracing", "tracing-subscriber", + "vec1", ] [[package]] @@ -4320,6 +4321,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "vec1" +version = "1.12.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "eab68b56840f69efb0fefbe3ab6661499217ffdc58e2eef7c3f6f69835386322" + [[package]] name = "version_check" version = "0.9.5" diff --git a/packages/cipherstash-proxy-integration/src/generate.rs b/packages/cipherstash-proxy-integration/src/generate.rs index 49e4bf06..57a4cec3 100644 --- a/packages/cipherstash-proxy-integration/src/generate.rs +++ b/packages/cipherstash-proxy-integration/src/generate.rs @@ -1,22 +1,17 @@ #[cfg(test)] mod tests { - use crate::common::{clear, connect_with_tls, id, trace, PROXY}; + use crate::common::trace; use cipherstash_client::config::EnvSource; use cipherstash_client::credentials::auto_refresh::AutoRefresh; - use cipherstash_client::ejsonpath::Selector; use cipherstash_client::encryption::{ - Encrypted, EncryptedEntry, EncryptedSteVecTerm, JsonIndexer, JsonIndexerOptions, OreTerm, - Plaintext, PlaintextTarget, QueryBuilder, ReferencedPendingPipeline, - }; - use cipherstash_client::{ - encryption::{ScopedCipher, SteVec}, - zerokms::{encrypted_record, EncryptedRecord}, + Encrypted, EncryptedSteVecTerm, JsonIndexer, JsonIndexerOptions, OreTerm, Plaintext, + PlaintextTarget, ReferencedPendingPipeline, }; + use cipherstash_client::{encryption::ScopedCipher, zerokms::EncryptedRecord}; use cipherstash_client::{ConsoleConfig, CtsConfig, ZeroKMSConfig}; use cipherstash_config::column::{Index, IndexType}; - use cipherstash_config::{ColumnConfig, ColumnMode, ColumnType}; + use cipherstash_config::{ColumnConfig, ColumnType}; use cipherstash_proxy::Identifier; - use rustls::unbuffered::EncodeError; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tracing::info; @@ -25,7 +20,7 @@ mod tests { pub mod option_mp_base85 { use cipherstash_client::zerokms::encrypted_record::formats::mp_base85; use cipherstash_client::zerokms::EncryptedRecord; - use serde::{Deserialize, Deserializer, Serialize, Serializer}; + use serde::{Deserialize, Deserializer, Serializer}; pub fn serialize( 
value: &Option, @@ -167,7 +162,7 @@ mod tests { // let mut value = // serde_json::from_str::("{\"hello\": \"two\", \"n\": 20}").unwrap(); - let mut value = + let value = serde_json::from_str::("{\"hello\": \"two\", \"n\": 30}").unwrap(); // let mut value = @@ -243,12 +238,12 @@ mod tests { let term = indexer.generate_term(term, cipher.index_key()).unwrap(); match term { - EncryptedSteVecTerm::Mac(items) => todo!(), + EncryptedSteVecTerm::Mac(_) => todo!(), EncryptedSteVecTerm::OreFixed(ore_cllw8_v1) => { let term = hex::encode(ore_cllw8_v1.bytes); info!("{n}: {term}"); } - EncryptedSteVecTerm::OreVariable(ore_cllw8_variable_v1) => todo!(), + EncryptedSteVecTerm::OreVariable(_) => todo!(), } // if let Some(ste_vec_index) = e.ste_vec_index { diff --git a/packages/cipherstash-proxy/src/encrypt/mod.rs b/packages/cipherstash-proxy/src/encrypt/mod.rs index 30bf52c9..24bf76d4 100644 --- a/packages/cipherstash-proxy/src/encrypt/mod.rs +++ b/packages/cipherstash-proxy/src/encrypt/mod.rs @@ -271,7 +271,7 @@ fn to_eql_encrypted( ore_cclw_var_index = Some(hex::encode(ore.bytes)) } }, - IndexTerm::SteQueryVec(query) => {} // TODO: what do we do here? + IndexTerm::SteQueryVec(_query) => {} // TODO: what do we do here? 
IndexTerm::Null => {} }; } diff --git a/packages/eql-mapper/Cargo.toml b/packages/eql-mapper/Cargo.toml index febb293e..3c63989e 100644 --- a/packages/eql-mapper/Cargo.toml +++ b/packages/eql-mapper/Cargo.toml @@ -19,6 +19,7 @@ sqltk = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } +vec1 = "1.12.1" [dev-dependencies] pretty_assertions = "^1.0" diff --git a/packages/eql-mapper/src/inference/infer_type_impls/function.rs b/packages/eql-mapper/src/inference/infer_type_impls/function.rs index 88347df6..585d6b7e 100644 --- a/packages/eql-mapper/src/inference/infer_type_impls/function.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/function.rs @@ -1,12 +1,16 @@ use eql_mapper_macros::trace_infer; -use sqltk::parser::ast::{Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident}; +use sqltk::parser::ast::{Function, FunctionArguments}; use crate::{ - inference::{type_error::TypeError, InferType}, - unifier::Type, - SqlIdent, TypeInferencer, + get_type_signature_for_special_cased_sql_function, inference::infer_type::InferType, + CompoundIdent, FunctionSig, TypeError, TypeInferencer, }; +/// Looks up the function signature. +/// +/// If a signature is found it means that function is handled as an EQL special case and is type checked accordingly. +/// +/// If a signature is not found then all function args and its return type are unified as native. #[trace_infer] impl<'ast> InferType<'ast, Function> for TypeInferencer<'ast> { fn infer_exit(&mut self, function: &'ast Function) -> Result<(), TypeError> { @@ -17,115 +21,14 @@ impl<'ast> InferType<'ast, Function> for TypeInferencer<'ast> { } let Function { name, args, .. } = function; + let fn_name = CompoundIdent::from(&name.0); - let fn_name: Vec<_> = name.0.iter().map(SqlIdent).collect(); - - if fn_name == [SqlIdent(&Ident::new("min"))] || fn_name == [SqlIdent(&Ident::new("max"))] { - // 1. 
There MUST be one unnamed argument (it CAN come from a subquery) - // 2. The return type is the same as the argument type - - match args { - FunctionArguments::None => { - return Err(TypeError::FunctionCall(format!( - "{} should be called with 1 argument, got 0", - fn_name.last().unwrap() - ))) - } - - FunctionArguments::Subquery(query) => { - // The query must return a single column projection which has the same type as the result of the - // call to min/max. - self.unify_node_with_type( - &**query, - Type::projection(&[(self.get_node_type(function), None)]), - )?; - } - - FunctionArguments::List(args_list) => { - if args_list.args.len() == 1 { - match &args_list.args[0] { - FunctionArg::Named { .. } | FunctionArg::ExprNamed { .. } => { - return Err(TypeError::FunctionCall(format!( - "{} cannot be called with named arguments", - fn_name.last().unwrap(), - ))) - } - - FunctionArg::Unnamed(function_arg_expr) => match function_arg_expr { - FunctionArgExpr::Expr(expr) => { - self.unify_nodes(function, expr)?; - } - - FunctionArgExpr::QualifiedWildcard(_) - | FunctionArgExpr::Wildcard => { - return Err(TypeError::FunctionCall(format!( - "{} cannot be called with wildcard arguments", - fn_name.last().unwrap(), - ))) - } - }, - } - } else { - return Err(TypeError::FunctionCall(format!( - "{} should be called with 1 argument, got {}", - fn_name.last().unwrap(), - args_list.args.len() - ))); - } - } + match get_type_signature_for_special_cased_sql_function(&fn_name, args) { + Some(sig) => { + sig.instantiate(&*self).apply_constraints(self, function)?; } - } else { - // All other functions: resolve to native - // EQL values will be rejected in function calls - self.unify_node_with_type(function, Type::any_native())?; - - match args { - // Function called without any arguments. 
- // Used for functions like `CURRENT_TIMESTAMP` that do not require parentheses () - // This is not the same as a function that has zero arguments (which would be an empty arg list) - FunctionArguments::None => {} - - FunctionArguments::Subquery(query) => { - // The query must return a single column projection which has the same type as the result of the function - self.unify_node_with_type( - &**query, - Type::projection(&[(self.get_node_type(function), None)]), - )?; - } - - FunctionArguments::List(args_list) => { - self.unify_node_with_type(function, Type::any_native())?; - for arg in &args_list.args { - match arg { - FunctionArg::ExprNamed { - name, - arg, - operator: _, - } => { - self.unify_node_with_type(name, Type::any_native())?; - match arg { - FunctionArgExpr::Expr(expr) => { - self.unify_node_with_type(expr, Type::any_native())?; - } - // Aggregate functions like COUNT(table.*) - FunctionArgExpr::QualifiedWildcard(_) => {} - // Aggregate functions like COUNT(*) - FunctionArgExpr::Wildcard => {} - } - } - FunctionArg::Named { arg, .. 
} | FunctionArg::Unnamed(arg) => match arg - { - FunctionArgExpr::Expr(expr) => { - self.unify_node_with_type(expr, Type::any_native())?; - } - // Aggregate functions like COUNT(table.*) - FunctionArgExpr::QualifiedWildcard(_) => {} - // Aggregate functions like COUNT(*) - FunctionArgExpr::Wildcard => {} - }, - } - } - } + None => { + FunctionSig::instantiate_native(function).apply_constraints(self, function)?; } } diff --git a/packages/eql-mapper/src/inference/infer_type_impls/function_arg_expr.rs b/packages/eql-mapper/src/inference/infer_type_impls/function_arg_expr.rs new file mode 100644 index 00000000..89e5d5df --- /dev/null +++ b/packages/eql-mapper/src/inference/infer_type_impls/function_arg_expr.rs @@ -0,0 +1,24 @@ +use eql_mapper_macros::trace_infer; +use sqltk::parser::ast::FunctionArgExpr; + +use crate::{inference::infer_type::InferType, TypeError, TypeInferencer}; + +#[trace_infer] +impl<'ast> InferType<'ast, FunctionArgExpr> for TypeInferencer<'ast> { + fn infer_exit(&mut self, farg_expr: &'ast FunctionArgExpr) -> Result<(), TypeError> { + let farg_expr_ty = self.get_node_type(farg_expr); + match farg_expr { + FunctionArgExpr::Expr(expr) => { + self.unify(farg_expr_ty, self.get_node_type(expr))?; + } + FunctionArgExpr::QualifiedWildcard(qualified) => { + self.unify(farg_expr_ty, self.resolve_qualified_wildcard(&qualified.0)?)?; + } + FunctionArgExpr::Wildcard => { + self.unify(farg_expr_ty, self.resolve_wildcard()?)?; + } + }; + + Ok(()) + } +} diff --git a/packages/eql-mapper/src/inference/infer_type_impls/mod.rs b/packages/eql-mapper/src/inference/infer_type_impls/mod.rs index 1f266dee..103a8cfd 100644 --- a/packages/eql-mapper/src/inference/infer_type_impls/mod.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/mod.rs @@ -1,6 +1,7 @@ // General AST nodes mod expr; mod function; +mod function_arg_expr; mod select; mod select_item; mod select_items; diff --git a/packages/eql-mapper/src/inference/mod.rs 
b/packages/eql-mapper/src/inference/mod.rs index 2e364944..b5a96a15 100644 --- a/packages/eql-mapper/src/inference/mod.rs +++ b/packages/eql-mapper/src/inference/mod.rs @@ -2,6 +2,8 @@ mod infer_type; mod infer_type_impls; mod registry; mod sequence; +mod sql_fn_macros; +mod sql_functions; mod type_error; pub mod unifier; @@ -12,7 +14,8 @@ use std::{cell::RefCell, fmt::Debug, marker::PhantomData, ops::ControlFlow, rc:: use infer_type::InferType; use sqltk::parser::ast::{ - Delete, Expr, Function, Ident, Insert, Query, Select, SelectItem, SetExpr, Statement, Values, + Delete, Expr, Function, FunctionArgExpr, Ident, Insert, Query, Select, SelectItem, SetExpr, + Statement, Values, }; use sqltk::{into_control_flow, AsNodeKey, Break, Visitable, Visitor}; @@ -20,6 +23,7 @@ use crate::{ScopeError, ScopeTracker, TableResolver}; pub(crate) use registry::*; pub(crate) use sequence::*; +pub(crate) use sql_functions::*; pub(crate) use type_error::*; /// [`Visitor`] implementation that performs type inference on AST nodes. @@ -187,6 +191,7 @@ macro_rules! dispatch_all { dispatch!($self, $method, $node, Vec); dispatch!($self, $method, $node, SelectItem); dispatch!($self, $method, $node, Function); + dispatch!($self, $method, $node, FunctionArgExpr); dispatch!($self, $method, $node, Values); dispatch!($self, $method, $node, sqltk::parser::ast::Value); }; diff --git a/packages/eql-mapper/src/inference/sql_fn_macros.rs b/packages/eql-mapper/src/inference/sql_fn_macros.rs new file mode 100644 index 00000000..8b09734b --- /dev/null +++ b/packages/eql-mapper/src/inference/sql_fn_macros.rs @@ -0,0 +1,30 @@ +#[macro_export] +macro_rules! to_kind { + (NATIVE) => { + crate::Kind::Native + }; + ($generic:ident) => { + crate::Kind::Generic(stringify!($generic)) + }; +} + +#[macro_export] +macro_rules! 
sql_fn_args { + (()) => { vec![] }; + + (($arg:ident)) => { vec![crate::to_kind!($arg)] }; + + (($arg:ident $(,$rest:ident)*)) => { + vec![crate::to_kind!($arg) $(,crate::to_kind!($rest))*] + }; +} + +#[macro_export] +macro_rules! sql_fn { + ($name:ident $args:tt -> $return_kind:ident) => { + crate::SqlFunction::new( + stringify!($name), + FunctionSig::new(crate::sql_fn_args!($args), crate::to_kind!($return_kind)), + ) + }; +} \ No newline at end of file diff --git a/packages/eql-mapper/src/inference/sql_functions.rs b/packages/eql-mapper/src/inference/sql_functions.rs new file mode 100644 index 00000000..95b4277a --- /dev/null +++ b/packages/eql-mapper/src/inference/sql_functions.rs @@ -0,0 +1,223 @@ +use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, LazyLock}, +}; + +use derive_more::derive::Display; +use sqltk::parser::ast::{Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident}; + +use itertools::Itertools; +use vec1::{vec1, Vec1}; + +use crate::{sql_fn, unifier::Type, SqlIdent, TypeInferencer}; + +use super::TypeError; + +#[derive(Debug)] +pub(crate) struct SqlFunction(CompoundIdent, FunctionSig); + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub(crate) enum Kind { + Native, + Generic(&'static str), +} + +#[derive(Debug, Clone)] +pub(crate) struct FunctionSig { + args: Vec, + return_type: Kind, + generics: HashSet<&'static str>, +} + +#[derive(Debug, Clone)] +pub(crate) struct InstantiatedSig { + args: Vec>, + return_type: Arc, +} + +impl FunctionSig { + fn new(args: Vec, return_type: Kind) -> Self { + let mut generics: HashSet<&'static str> = HashSet::new(); + + for arg in &args { + if let Kind::Generic(generic) = arg { + generics.insert(*generic); + } + } + + if let Kind::Generic(generic) = return_type { + generics.insert(generic); + } + + Self { + args, + return_type, + generics, + } + } + + pub(crate) fn is_applicable_to_args(&self, fn_args_syntax: &FunctionArguments) -> bool { + match fn_args_syntax { + FunctionArguments::None => 
self.args.is_empty(), + FunctionArguments::Subquery(_) => self.args.len() == 1, + FunctionArguments::List(fn_args) => self.args.len() == fn_args.args.len(), + } + } + + pub(crate) fn instantiate(&self, inferencer: &TypeInferencer<'_>) -> InstantiatedSig { + let mut generics: HashMap<&'static str, Arc> = HashMap::new(); + + for generic in self.generics.iter() { + generics.insert(generic, inferencer.fresh_tvar()); + } + + InstantiatedSig { + args: self + .args + .iter() + .map(|kind| match kind { + Kind::Native => Arc::new(Type::any_native()), + Kind::Generic(generic) => generics[generic].clone(), + }) + .collect(), + + return_type: match self.return_type { + Kind::Native => Arc::new(Type::any_native()), + Kind::Generic(generic) => generics[generic].clone(), + }, + } + } + + pub(crate) fn instantiate_native(function: &Function) -> InstantiatedSig { + let arg_count = match &function.args { + FunctionArguments::None => 0, + FunctionArguments::Subquery(_) => 1, + FunctionArguments::List(args) => args.args.len(), + }; + + let args: Vec> = (0..arg_count) + .into_iter() + .map(|_| Arc::new(Type::any_native())) + .collect(); + + InstantiatedSig { + args, + return_type: Arc::new(Type::any_native()), + } + } +} + +impl InstantiatedSig { + pub(crate) fn apply_constraints<'ast>( + &self, + inferencer: &mut TypeInferencer<'ast>, + function: &'ast Function, + ) -> Result<(), TypeError> { + let fn_name = CompoundIdent::from(&function.name.0); + + // let function_ty = inferencer.get_node_type(function); + + inferencer.unify_node_with_type(function, self.return_type.clone())?; + + match &function.args { + FunctionArguments::None => { + if self.args.len() == 0 { + Ok(()) + } else { + Err(TypeError::Conflict(format!( + "expected {} args to function {}; got 0", + self.args.len(), + fn_name + ))) + } + } + + FunctionArguments::Subquery(query) => { + if self.args.len() == 1 { + inferencer.unify_node_with_type(&**query, self.args[0].clone())?; + Ok(()) + } else { + 
Err(TypeError::Conflict(format!( + "expected {} args to function {}; got 0", + self.args.len(), + fn_name + ))) + } + } + + FunctionArguments::List(args) => { + for (sig_arg, fn_arg) in self.args.iter().zip(args.args.iter()) { + let farg_expr = get_function_arg_expr(fn_arg); + inferencer.unify_node_with_type(farg_expr, sig_arg.clone())?; + } + + Ok(()) + } + } + } +} + +fn get_function_arg_expr(fn_arg: &FunctionArg) -> &FunctionArgExpr { + match fn_arg { + FunctionArg::Named { arg, .. } => arg, + FunctionArg::ExprNamed { arg, .. } => arg, + FunctionArg::Unnamed(arg) => arg, + } +} + +impl SqlFunction { + fn new(ident: &str, sig: FunctionSig) -> Self { + Self(CompoundIdent::from(ident), sig) + } +} + +#[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Clone, Display)] +#[display("{}", _0.iter().map(SqlIdent::to_string).collect::>().join("."))] +pub(crate) struct CompoundIdent(Vec1>); + +impl From<&str> for CompoundIdent { + fn from(value: &str) -> Self { + CompoundIdent(vec1![SqlIdent(Ident::new(value))]) + } +} + +impl From<&Vec> for CompoundIdent { + fn from(value: &Vec) -> Self { + let mut idents = Vec1::>::new(SqlIdent(value[0].clone())); + idents.extend(value[1..].into_iter().cloned().map(SqlIdent)); + CompoundIdent(idents) + } +} + +static SQL_FUNCTION_SIGNATURES: LazyLock>> = LazyLock::new(|| { + // Notation: a single uppercase letter denotes an unknown type. Matching letters in a signature will be assigned + // *the same type variable* and thus must resolve to the same type. (🙏 Haskell) + // + // Eventually we should type check EQL types against their configured indexes instead of leaving that to the EQL + // extension in the database. 
I can imagine supporting type bounds in signatures here, such as: `T: Eq` + let sql_fns = vec![ + sql_fn!(count(T) -> NATIVE), + sql_fn!(min(T) -> T), + sql_fn!(max(T) -> T), + sql_fn!(jsonb_path_query(T, T) -> T), + ]; + + let mut sql_fns_by_name: HashMap> = HashMap::new(); + + for (key, chunk) in &sql_fns.into_iter().chunk_by(|sql_fn| sql_fn.0.clone()) { + sql_fns_by_name.insert( + key.clone(), + chunk.into_iter().map(|sql_fn| sql_fn.1).collect(), + ); + } + + sql_fns_by_name +}); + +pub(crate) fn get_type_signature_for_special_cased_sql_function( + fn_name: &CompoundIdent, + args: &FunctionArguments, +) -> Option<&'static FunctionSig> { + let sigs = SQL_FUNCTION_SIGNATURES.get(fn_name)?; + sigs.iter().find(|sig| sig.is_applicable_to_args(args)) +} diff --git a/packages/eql-mapper/src/lib.rs b/packages/eql-mapper/src/lib.rs index 920d1d58..a4d7437d 100644 --- a/packages/eql-mapper/src/lib.rs +++ b/packages/eql-mapper/src/lib.rs @@ -42,9 +42,11 @@ mod test { Value, }; use pretty_assertions::assert_eq; + use sqltk::parser::ast::Ident; use sqltk::parser::ast::Statement; use sqltk::parser::ast::{self as ast}; use sqltk::AsNodeKey; + use sqltk::NodeKey; use std::collections::HashMap; use std::sync::Arc; use tracing::error; @@ -1408,7 +1410,82 @@ mod test { test_jsonb_operator("<@"); } - fn test_jsonb_operator(op: &'static str) { + #[test] + fn jsonb_function_jsonb_path_query() { + test_jsonb_function( + "jsonb_path_query", + vec![ + ast::Expr::Identifier(Ident::new("notes")), + ast::Expr::Value(ast::Value::SingleQuotedString("$.medications".to_owned())), + ], + ); + } + + // TODO: do we need to check that the RHS of JSON operators MUST be a Value node + // and not an arbitrary expression? + + fn test_jsonb_function(fn_name: &str, args: Vec) { + let schema = resolver(schema! 
{ + tables: { + patients: { + id (PK), + notes (EQL), + } + } + }); + + let args_in = args + .iter() + .map(|expr| expr.to_string()) + .collect::>() + .join(", "); + + let statement = parse(&format!( + "SELECT id, {}({}) AS meds FROM patients", + fn_name, args_in + )); + + let args_encrypted = args + .iter() + .map(|expr| match expr { + ast::Expr::Identifier(ident) => ident.to_string(), + ast::Expr::Value(ast::Value::SingleQuotedString(s)) => { + format!("''", s) + } + _ => panic!("unsupported expr type in test util"), + }) + .collect::>() + .join(", "); + + let mut encrypted_literals: HashMap, ast::Value> = HashMap::new(); + + for arg in args.iter() { + if let ast::Expr::Value(value) = arg { + encrypted_literals.extend(test_helpers::dummy_encrypted_json_selector( + &statement, + value.clone(), + )); + } + } + + match type_check(schema, &statement) { + Ok(typed) => match typed.transform(encrypted_literals) { + Ok(statement) => { + assert_eq!( + statement.to_string(), + format!( + "SELECT id, {}({}) AS meds FROM patients", + fn_name, args_encrypted + ) + ) + } + Err(err) => panic!("transformation failed: {err}"), + }, + Err(err) => panic!("type check failed: {err}"), + } + } + + fn test_jsonb_operator(op: &str) { let schema = resolver(schema! 
{ tables: { patients: { @@ -1418,11 +1495,14 @@ mod test { } }); - let statement = parse(&format!("SELECT id, notes {} 'medications' AS meds FROM patients", op)); + let statement = parse(&format!( + "SELECT id, notes {} 'medications' AS meds FROM patients", + op + )); match type_check(schema, &statement) { Ok(typed) => { - match typed.transform(test_helpers::dummy_encrypted_json_selector(&typed, "medications")) { + match typed.transform(test_helpers::dummy_encrypted_json_selector(&statement, ast::Value::SingleQuotedString("medications".to_owned()))) { Ok(statement) => assert_eq!( statement.to_string(), format!("SELECT id, notes {} '' AS meds FROM patients", op) diff --git a/packages/eql-mapper/src/test_helpers.rs b/packages/eql-mapper/src/test_helpers.rs index 4457bdbe..5af2fde7 100644 --- a/packages/eql-mapper/src/test_helpers.rs +++ b/packages/eql-mapper/src/test_helpers.rs @@ -1,19 +1,19 @@ -use std::{collections::HashMap, fmt::Debug}; +use std::{collections::HashMap, convert::Infallible, fmt::Debug, ops::ControlFlow}; use sqltk::{ parser::{ - ast::{self as ast, Statement}, + ast::{self as ast, Statement, Value}, dialect::PostgreSqlDialect, parser::Parser, }, - NodeKey, + AsNodeKey, Break, NodeKey, Visitable, Visitor, }; use tracing_subscriber::fmt::format; use tracing_subscriber::fmt::format::FmtSpan; use std::sync::Once; -use crate::{Projection, ProjectionColumn, TypeCheckedStatement}; +use crate::{Projection, ProjectionColumn}; #[allow(unused)] pub(crate) fn init_tracing() { @@ -39,24 +39,60 @@ pub(crate) fn id(ident: &str) -> ast::Ident { } pub(crate) fn get_node_key_of_json_selector<'ast>( - typed: &TypeCheckedStatement<'ast>, - selector: &'static str, + statement: &'ast Statement, + selector: &Value, ) -> NodeKey<'ast> { - typed - .find_nodekey_for_value_node(ast::Value::SingleQuotedString(selector.into())) + find_nodekey_for_value_node(statement, selector.clone()) .expect("could not find selector Value node") } pub(crate) fn 
dummy_encrypted_json_selector<'ast>( - typed: &TypeCheckedStatement<'ast>, - selector: &'static str, + statement: &'ast Statement, + selector: Value, ) -> HashMap, ast::Value> { - HashMap::from_iter(vec![( - get_node_key_of_json_selector(typed, selector), - ast::Value::SingleQuotedString(format!("", selector)), - )]) + if let Value::SingleQuotedString(s) = &selector { + return HashMap::from_iter(vec![( + get_node_key_of_json_selector(statement, &selector), + ast::Value::SingleQuotedString(format!("", s)), + )]) + } else { + panic!("dummy_encrypted_json_selector only works on Value::SingleQuotedString") + } } +/// Utility for finding the [`NodeKey`] of a [`Value`] node in `statement` by providing a `matching` equal node to search for. +pub(crate) fn find_nodekey_for_value_node<'ast>( + statement: &'ast Statement, + matching: ast::Value, +) -> Option> { + struct FindNode<'ast> { + needle: ast::Value, + found: Option>, + } + + impl<'a> Visitor<'a> for FindNode<'a> { + type Error = Infallible; + + fn enter(&mut self, node: &'a N) -> ControlFlow> { + if let Some(haystack) = node.downcast_ref::() { + if haystack == &self.needle { + self.found = Some(haystack.as_node_key()); + return ControlFlow::Break(Break::Finished); + } + } + ControlFlow::Continue(()) + } + } + + let mut visitor = FindNode { + needle: matching, + found: None, + }; + + statement.accept(&mut visitor); + + visitor.found +} #[macro_export] macro_rules! 
col { ((NATIVE)) => { diff --git a/packages/eql-mapper/src/type_checked_statement.rs b/packages/eql-mapper/src/type_checked_statement.rs index 81758446..43012b7d 100644 --- a/packages/eql-mapper/src/type_checked_statement.rs +++ b/packages/eql-mapper/src/type_checked_statement.rs @@ -1,10 +1,7 @@ -use std::any::TypeId; -use std::convert::Infallible; -use std::ops::ControlFlow; use std::{collections::HashMap, sync::Arc}; -use sqltk::parser::ast::{self, Query, SetExpr, Statement}; -use sqltk::{AsNodeKey, Break, NodeKey, Transformable, Visitable, Visitor}; +use sqltk::parser::ast::{self, Statement}; +use sqltk::{AsNodeKey, NodeKey, Transformable}; use crate::{ DryRunnable, EqlMapperError, EqlValue, FailOnPlaceholderChange, GroupByEqlCol, Param, @@ -84,34 +81,6 @@ impl<'ast> TypeCheckedStatement<'ast> { self.statement.apply_transform(&mut transformer) } - /// Utility for finding the [`NodeKey`] of a [`Value`] node in `statement` by providing a `matching` equal node to search for. - #[cfg(test)] - pub(crate) fn find_nodekey_for_value_node(&self, matching: ast::Value) -> Option> { - struct FindNode<'ast> { - needle: ast::Value, - found: Option>, - } - - impl<'a> Visitor<'a> for FindNode<'a> { - type Error = Infallible; - - fn enter(&mut self, node: &'a N) -> ControlFlow> { - if let Some(haystack) = node.downcast_ref::() { - if haystack == &self.needle { - self.found = Some(haystack.as_node_key()); - return ControlFlow::Break(Break::Finished) - } - } - ControlFlow::Continue(()) - } - } - - let mut visitor = FindNode{ needle: matching, found: None }; - self.statement.accept(&mut visitor); - - visitor.found - } - pub fn literal_values(&self) -> Vec<&sqltk::parser::ast::Value> { self.literals .iter() From f40e92ec898ca9d7ed2eb7eb8f046b3a4cdaa02f Mon Sep 17 00:00:00 2001 From: James Sadler Date: Mon, 5 May 2025 09:38:55 +1000 Subject: [PATCH 5/7] fix(mapper): hash function for SqlIdent must take quote style into account --- packages/eql-mapper/src/lib.rs | 4 ++-- 
packages/eql-mapper/src/model/sql_ident.rs | 28 ++++++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/packages/eql-mapper/src/lib.rs b/packages/eql-mapper/src/lib.rs index a4d7437d..f9694a0c 100644 --- a/packages/eql-mapper/src/lib.rs +++ b/packages/eql-mapper/src/lib.rs @@ -1339,14 +1339,14 @@ mod test { }); let statement = - parse("SELECT MIN(salary), MAX(salary), department FROM employees GROUP BY department"); + parse("SELECT min(salary), max(salary), department FROM employees GROUP BY department"); match type_check(schema, &statement) { Ok(typed) => { match typed.transform(HashMap::new()) { Ok(statement) => assert_eq!( statement.to_string(), - "SELECT CS_MIN_V1(salary) AS MIN, CS_MAX_V1(salary) AS MAX, department FROM employees GROUP BY department".to_string() + "SELECT CS_MIN_V1(salary) AS min, CS_MAX_V1(salary) AS max, department FROM employees GROUP BY department".to_string() ), Err(err) => panic!("transformation failed: {err}"), } diff --git a/packages/eql-mapper/src/model/sql_ident.rs b/packages/eql-mapper/src/model/sql_ident.rs index acaa91da..63bdc9eb 100644 --- a/packages/eql-mapper/src/model/sql_ident.rs +++ b/packages/eql-mapper/src/model/sql_ident.rs @@ -102,14 +102,28 @@ impl SqlIdent { } } -// This manual Hash implementation is required to prevent a clippy error: -// "error: you are deriving `Hash` but have implemented `PartialEq` explicitly" -impl Hash for SqlIdent -where - T: Hash, -{ +// This Hash implementation (and the following one) is required in order to be consistent with PartialEq.
+impl Hash for SqlIdent<&Ident> { + fn hash(&self, state: &mut H) { + match self.0.quote_style { + Some(ch) => { + state.write_u8(1); + state.write_u32(ch as u32); + state.write(self.0.value.as_bytes()); + }, + None => { + state.write_u8(0); + for ch in self.0.value.chars().map(|ch| ch.to_lowercase()).flatten() { + state.write_u32(ch as u32); + } + }, + } + } +} + +impl Hash for SqlIdent { fn hash(&self, state: &mut H) { - self.0.hash(state) + SqlIdent(&self.0).hash(state) } } From c27ac3c90324f1117c7b0cab6f8c3f297be69a12 Mon Sep 17 00:00:00 2001 From: James Sadler Date: Mon, 5 May 2025 10:36:18 +1000 Subject: [PATCH 6/7] docs: rustdoc FunctionSig etc --- .../eql-mapper/src/inference/sql_functions.rs | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/packages/eql-mapper/src/inference/sql_functions.rs b/packages/eql-mapper/src/inference/sql_functions.rs index 95b4277a..ec8e0d4d 100644 --- a/packages/eql-mapper/src/inference/sql_functions.rs +++ b/packages/eql-mapper/src/inference/sql_functions.rs @@ -13,15 +13,23 @@ use crate::{sql_fn, unifier::Type, SqlIdent, TypeInferencer}; use super::TypeError; +/// The identifier and type signature of a SQL function. +/// +/// See [`SQL_FUNCTION_SIGNATURES`]. #[derive(Debug)] pub(crate) struct SqlFunction(CompoundIdent, FunctionSig); +/// A representation of the type of an argument or return type in a SQL function. #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub(crate) enum Kind { + /// A type that must be a native type Native, + + /// A type that can be a native or EQL type. The `str` is the generic variable name. Generic(&'static str), } +/// The type signature of a SQL function (excluding its name). #[derive(Debug, Clone)] pub(crate) struct FunctionSig { args: Vec, @@ -29,6 +37,8 @@ pub(crate) struct FunctionSig { generics: HashSet<&'static str>, } +/// A function signature but filled in with fresh type variables that correspond with the [`Kind`] of each argument and +/// return type.
#[derive(Debug, Clone)] pub(crate) struct InstantiatedSig { args: Vec>, @@ -56,6 +66,7 @@ impl FunctionSig { } } + /// Checks if `self` is applicable to a particular piece of SQL function invocation syntax. pub(crate) fn is_applicable_to_args(&self, fn_args_syntax: &FunctionArguments) -> bool { match fn_args_syntax { FunctionArguments::None => self.args.is_empty(), @@ -64,6 +75,7 @@ impl FunctionSig { } } + /// Creates an [`InstantiatedSig`] from `self`, filling in the [`Kind`]s with fresh type variables. pub(crate) fn instantiate(&self, inferencer: &TypeInferencer<'_>) -> InstantiatedSig { let mut generics: HashMap<&'static str, Arc> = HashMap::new(); @@ -88,6 +100,8 @@ impl FunctionSig { } } + /// For functions that do not have special case handling we synthesise an [`InstantiatedSig`] from the SQL function + /// invocation syntax where all arguments and the return types are native. pub(crate) fn instantiate_native(function: &Function) -> InstantiatedSig { let arg_count = match &function.args { FunctionArguments::None => 0, @@ -108,6 +122,7 @@ impl FunctionSig { } impl InstantiatedSig { + /// Applies the type constraints of the function to the AST. pub(crate) fn apply_constraints<'ast>( &self, inferencer: &mut TypeInferencer<'ast>, @@ -115,8 +130,6 @@ impl InstantiatedSig { ) -> Result<(), TypeError> { let fn_name = CompoundIdent::from(&function.name.0); - // let function_ty = inferencer.get_node_type(function); - inferencer.unify_node_with_type(function, self.return_type.clone())?; match &function.args { @@ -189,6 +202,7 @@ impl From<&Vec> for CompoundIdent { } } +/// SQL functions that are handled with special case type checking rules. static SQL_FUNCTION_SIGNATURES: LazyLock>> = LazyLock::new(|| { // Notation: a single uppercase letter denotes an unknown type. Matching letters in a signature will be assigned // *the same type variable* and thus must resolve to the same type.
(🙏 Haskell) @@ -200,6 +214,8 @@ static SQL_FUNCTION_SIGNATURES: LazyLock sql_fn!(min(T) -> T), sql_fn!(max(T) -> T), sql_fn!(jsonb_path_query(T, T) -> T), + sql_fn!(jsonb_path_query_first(T, T) -> T), + sql_fn!(jsonb_path_exists(T, T) -> T), ]; let mut sql_fns_by_name: HashMap> = HashMap::new(); From 95d9f8dd1965938e398c39bb0361db5410d7ceff Mon Sep 17 00:00:00 2001 From: James Sadler Date: Mon, 5 May 2025 11:58:26 +1000 Subject: [PATCH 7/7] chore: fmt & clippy --- .../eql-mapper/src/inference/sql_fn_macros.rs | 14 ++--- .../eql-mapper/src/inference/sql_functions.rs | 56 +++++++++---------- packages/eql-mapper/src/model/sql_ident.rs | 8 +-- packages/eql-mapper/src/test_helpers.rs | 14 ++--- 4 files changed, 46 insertions(+), 46 deletions(-) diff --git a/packages/eql-mapper/src/inference/sql_fn_macros.rs b/packages/eql-mapper/src/inference/sql_fn_macros.rs index 8b09734b..4fcd9bc7 100644 --- a/packages/eql-mapper/src/inference/sql_fn_macros.rs +++ b/packages/eql-mapper/src/inference/sql_fn_macros.rs @@ -1,10 +1,10 @@ #[macro_export] macro_rules! to_kind { (NATIVE) => { - crate::Kind::Native + $crate::Kind::Native }; ($generic:ident) => { - crate::Kind::Generic(stringify!($generic)) + $crate::Kind::Generic(stringify!($generic)) }; } @@ -12,19 +12,19 @@ macro_rules! to_kind { macro_rules! sql_fn_args { (()) => { vec![] }; - (($arg:ident)) => { vec![crate::to_kind!($arg)] }; + (($arg:ident)) => { vec![$crate::to_kind!($arg)] }; (($arg:ident $(,$rest:ident)*)) => { - vec![crate::to_kind!($arg) $(,crate::to_kind!($rest))*] + vec![$crate::to_kind!($arg) $(, $crate::to_kind!($rest))*] }; } #[macro_export] macro_rules! 
sql_fn { ($name:ident $args:tt -> $return_kind:ident) => { - crate::SqlFunction::new( + $crate::SqlFunction::new( stringify!($name), - FunctionSig::new(crate::sql_fn_args!($args), crate::to_kind!($return_kind)), + FunctionSig::new($crate::sql_fn_args!($args), $crate::to_kind!($return_kind)), ) }; -} \ No newline at end of file +} diff --git a/packages/eql-mapper/src/inference/sql_functions.rs b/packages/eql-mapper/src/inference/sql_functions.rs index ec8e0d4d..7a2d42d4 100644 --- a/packages/eql-mapper/src/inference/sql_functions.rs +++ b/packages/eql-mapper/src/inference/sql_functions.rs @@ -110,7 +110,6 @@ impl FunctionSig { }; let args: Vec> = (0..arg_count) - .into_iter() .map(|_| Arc::new(Type::any_native())) .collect(); @@ -134,7 +133,7 @@ impl InstantiatedSig { match &function.args { FunctionArguments::None => { - if self.args.len() == 0 { + if self.args.is_empty() { Ok(()) } else { Err(TypeError::Conflict(format!( @@ -197,38 +196,39 @@ impl From<&str> for CompoundIdent { impl From<&Vec> for CompoundIdent { fn from(value: &Vec) -> Self { let mut idents = Vec1::>::new(SqlIdent(value[0].clone())); - idents.extend(value[1..].into_iter().cloned().map(SqlIdent)); + idents.extend(value[1..].iter().cloned().map(SqlIdent)); CompoundIdent(idents) } } /// SQL functions that are handled with special case type checking rules. -static SQL_FUNCTION_SIGNATURES: LazyLock>> = LazyLock::new(|| { - // Notation: a single uppercase letter denotes an unknown type. Matching letters in a signature will be assigned - // *the same type variable* and thus must resolve to the same type. (🙏 Haskell) - // - // Eventually we should type check EQL types against their configured indexes instead of leaving that to the EQL - // extension in the database. 
I can imagine supporting type bounds in signatures here, such as: `T: Eq` - let sql_fns = vec![ - sql_fn!(count(T) -> NATIVE), - sql_fn!(min(T) -> T), - sql_fn!(max(T) -> T), - sql_fn!(jsonb_path_query(T, T) -> T), - sql_fn!(jsonb_path_query_first(T, T) -> T), - sql_fn!(jsonb_path_exists(T, T) -> T), - ]; - - let mut sql_fns_by_name: HashMap> = HashMap::new(); - - for (key, chunk) in &sql_fns.into_iter().chunk_by(|sql_fn| sql_fn.0.clone()) { - sql_fns_by_name.insert( - key.clone(), - chunk.into_iter().map(|sql_fn| sql_fn.1).collect(), - ); - } +static SQL_FUNCTION_SIGNATURES: LazyLock>> = + LazyLock::new(|| { + // Notation: a single uppercase letter denotes an unknown type. Matching letters in a signature will be assigned + // *the same type variable* and thus must resolve to the same type. (🙏 Haskell) + // + // Eventually we should type check EQL types against their configured indexes instead of leaving that to the EQL + // extension in the database. I can imagine supporting type bounds in signatures here, such as: `T: Eq` + let sql_fns = vec![ + sql_fn!(count(T) -> NATIVE), + sql_fn!(min(T) -> T), + sql_fn!(max(T) -> T), + sql_fn!(jsonb_path_query(T, T) -> T), + sql_fn!(jsonb_path_query_first(T, T) -> T), + sql_fn!(jsonb_path_exists(T, T) -> T), + ]; + + let mut sql_fns_by_name: HashMap> = HashMap::new(); + + for (key, chunk) in &sql_fns.into_iter().chunk_by(|sql_fn| sql_fn.0.clone()) { + sql_fns_by_name.insert( + key.clone(), + chunk.into_iter().map(|sql_fn| sql_fn.1).collect(), + ); + } - sql_fns_by_name -}); + sql_fns_by_name + }); pub(crate) fn get_type_signature_for_special_cased_sql_function( fn_name: &CompoundIdent, diff --git a/packages/eql-mapper/src/model/sql_ident.rs b/packages/eql-mapper/src/model/sql_ident.rs index 63bdc9eb..aa5a0a3b 100644 --- a/packages/eql-mapper/src/model/sql_ident.rs +++ b/packages/eql-mapper/src/model/sql_ident.rs @@ -110,20 +110,20 @@ impl Hash for SqlIdent<&Ident> { state.write_u8(1); state.write_u32(ch as u32); 
state.write(self.0.value.as_bytes()); - }, + } None => { state.write_u8(0); - for ch in self.0.value.chars().map(|ch| ch.to_lowercase()).flatten() { + for ch in self.0.value.chars().flat_map(|ch| ch.to_lowercase()) { state.write_u32(ch as u32); } - }, + } } } } impl Hash for SqlIdent { fn hash(&self, state: &mut H) { - SqlIdent(&self.0).hash(state) + SqlIdent(&self.0).hash(state) } } diff --git a/packages/eql-mapper/src/test_helpers.rs b/packages/eql-mapper/src/test_helpers.rs index 5af2fde7..1d9285b2 100644 --- a/packages/eql-mapper/src/test_helpers.rs +++ b/packages/eql-mapper/src/test_helpers.rs @@ -46,12 +46,12 @@ pub(crate) fn get_node_key_of_json_selector<'ast>( .expect("could not find selector Value node") } -pub(crate) fn dummy_encrypted_json_selector<'ast>( - statement: &'ast Statement, +pub(crate) fn dummy_encrypted_json_selector( + statement: &Statement, selector: Value, -) -> HashMap, ast::Value> { +) -> HashMap, ast::Value> { if let Value::SingleQuotedString(s) = &selector { - return HashMap::from_iter(vec![( + HashMap::from_iter(vec![( get_node_key_of_json_selector(statement, &selector), ast::Value::SingleQuotedString(format!("", s)), )]) @@ -61,10 +61,10 @@ pub(crate) fn dummy_encrypted_json_selector<'ast>( } /// Utility for finding the [`NodeKey`] of a [`Value`] node in `statement` by providing a `matching` equal node to search for. -pub(crate) fn find_nodekey_for_value_node<'ast>( - statement: &'ast Statement, +pub(crate) fn find_nodekey_for_value_node( + statement: &Statement, matching: ast::Value, -) -> Option> { +) -> Option> { struct FindNode<'ast> { needle: ast::Value, found: Option>,