diff --git a/Cargo.toml b/Cargo.toml index 245ca14..6a46d0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ deltalake-azure = { version = "0.3.0", optional = true } dynamodb_lock = { version = "0.6.0", optional = true } # sentry sentry = { version = "0.23.0", optional = true } +regex = "1.10.2" [features] default = [] @@ -66,6 +67,12 @@ serial_test = "*" tempfile = "3" time = "0.3.20" utime = "0.3" +criterion = "0.5.1" +tar = "0.4" + +[[bench]] +name = "filters" +harness = false [profile.release] lto = true diff --git a/benches/filters.rs b/benches/filters.rs new file mode 100644 index 0000000..c0ddbc3 --- /dev/null +++ b/benches/filters.rs @@ -0,0 +1,56 @@ +use std::fs::File; +use std::io::{self, BufRead, BufReader}; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use serde_json::Value; + +use kafka_delta_ingest::{Filter, FilterEngine, FilterError, FilterFactory}; + +const SOURCE_PATH: &str = "tests/json/web_requests-100.json"; + +fn read_json_file(file_path: &str) -> io::Result> { + let file = File::open(file_path)?; + let reader = BufReader::new(file); + let lines: Vec<_> = reader.lines().take(30000).collect::>()?; + + let values: Vec = lines + .iter() + .map(|line| serde_json::from_str::(&line).unwrap()) + .collect(); + + Ok(values) +} + +fn filtering(filter: &Box, values: &Vec) { + for v in values.into_iter() { + match filter.filter(v) { + Ok(_) => {} + Err(e) => match e { + FilterError::FilterSkipMessage => {} + _ => panic!("something wrong"), + }, + }; + } +} + +fn naive_filter_benchmark(c: &mut Criterion) { + let values = read_json_file(SOURCE_PATH).unwrap(); + let filter = FilterFactory::try_build(&FilterEngine::Naive, &vec!["method=='GET'".to_string()]) + .expect("wrong"); + c.bench_function("naive_filter_benchmark", |b| { + b.iter(|| filtering(&filter, black_box(&values))) + }); +} + +fn jmespath_filter_benchmark(c: &mut Criterion) { + let values = read_json_file(SOURCE_PATH).unwrap(); + let filter = + FilterFactory::try_build(&FilterEngine::Jmespath, &vec!["method=='GET'".to_string()]) + .expect("wrong"); + c.bench_function("jmespath_filter_benchmark", |b| { + b.iter(|| filtering(&filter, black_box(&values))) + }); +} + +criterion_group!(benches, naive_filter_benchmark, jmespath_filter_benchmark); +criterion_main!(benches); diff --git a/src/filters/error.rs b/src/filters/error.rs new file mode 100644 index 0000000..a3204e0 --- /dev/null +++ b/src/filters/error.rs @@ -0,0 +1,49 @@ +use jmespatch::JmespathError; + +use crate::filters::naive_filter::error::NaiveFilterError; + +/// Errors returned by filters +#[derive(thiserror::Error, Debug)] +pub enum FilterError { + /// Failed compile expression + #[error("Failed compile expression: {source}")] + CompileExpressionError { + /// Wrapped [JmespathError] + source: JmespathError, + }, + + /// Message does not match filter pattern + #[error("Can't filter message: {source}")] + JmespathError { + /// Wrapped [JmespathError] + #[from] + source: JmespathError, + }, + + /// NaiveFilterError + #[error("NaiveFilter failure: {source}")] + NaiveFilterError { + /// Wrapped [`crate::filters::naive_filter::error::NaiveFilterError`] + #[from] + source: NaiveFilterError, + }, + + /// Error from [`serde_json`] + #[error("JSON serialization failed: {source}")] + SerdeJson { + /// Wrapped [`serde_json::Error`] + #[from] + source: serde_json::Error, + }, + + /// Filter engine not found + #[error("Filter engine not found: {name}")] + NotFound { + /// Wrong name + name: String, + }, + + /// Error returned for skipping message + #[error("Skipped a message by filter")] + FilterSkipMessage, +} diff --git a/src/filters/filter.rs b/src/filters/filter.rs new file mode 100644 index 0000000..d07ce06 --- /dev/null +++ b/src/filters/filter.rs @@ -0,0 +1,15 @@ +use serde_json::Value; + +use crate::filters::FilterError; + +/// Trait for implementing a filter mechanism +pub trait Filter: Send { + /// Constructor + fn from_filters(filters: &[String]) -> Result + where + Self: Sized; + + /// A function that filters a message. If any of the filters fail, it throws an error; + /// if all filters pass, it returns nothing. + fn filter(&self, message: &Value) -> Result<(), FilterError>; +} diff --git a/src/filters/filter_factory.rs b/src/filters/filter_factory.rs new file mode 100644 index 0000000..ce9a387 --- /dev/null +++ b/src/filters/filter_factory.rs @@ -0,0 +1,31 @@ +use crate::filters::{Filter, FilterError, JmespathFilter, NaiveFilter}; + +/// Filter options +#[derive(Clone, Debug)] +pub enum FilterEngine { + /// Filter for simple comparisons that works a little faster + Naive, + /// Filter for complex comparisons + Jmespath, +} + +/// Factory for creating and managing filters +pub struct FilterFactory {} +impl FilterFactory { + /// Factory for creating filter instances + pub fn try_build( + filter_engine: &FilterEngine, + filters: &[String], + ) -> Result, FilterError> { + match filter_engine { + FilterEngine::Naive => match NaiveFilter::from_filters(filters) { + Ok(f) => Ok(Box::new(f)), + Err(e) => Err(e), + }, + FilterEngine::Jmespath => match JmespathFilter::from_filters(filters) { + Ok(f) => Ok(Box::new(f)), + Err(e) => Err(e), + }, + } + } +} diff --git a/src/filters/jmespath_filter/custom_functions.rs b/src/filters/jmespath_filter/custom_functions.rs new file mode 100644 index 0000000..28fd8fd --- /dev/null +++ b/src/filters/jmespath_filter/custom_functions.rs @@ -0,0 +1,41 @@ +use std::convert::TryFrom; +use std::sync::Arc; + +use jmespatch::functions::{ArgumentType, CustomFunction, Signature}; +use jmespatch::{Context, ErrorReason, JmespathError, Rcvar, Variable}; + +/// Custom function to compare two string values in a case-insensitive manner +fn eq_ignore_case(args: &[Rcvar], context: &mut Context) -> Result { + let s = match args[0].as_string() { + None => { + return Err(JmespathError::new( + context.expression, + context.offset, + ErrorReason::Parse("first variable must be string".to_string()), + )) + } + Some(s) => s, + }; + + let p = match args[1].as_string() { + None => { + return Err(JmespathError::new( + context.expression, + context.offset, + ErrorReason::Parse("second variable must be string".to_string()), + )) + } + Some(p) => p, + }; + + let var = Variable::try_from(serde_json::Value::Bool(s.eq_ignore_ascii_case(p)))?; + + Ok(Arc::new(var)) +} + +pub fn create_eq_ignore_case_function() -> CustomFunction { + CustomFunction::new( + Signature::new(vec![ArgumentType::String, ArgumentType::String], None), + Box::new(eq_ignore_case), + ) +} diff --git a/src/filters/jmespath_filter/filter.rs b/src/filters/jmespath_filter/filter.rs new file mode 100644 index 0000000..1fda756 --- /dev/null +++ b/src/filters/jmespath_filter/filter.rs @@ -0,0 +1,175 @@ +use jmespatch::{Expression, Runtime}; +use serde_json::Value; + +use crate::filters::filter::Filter; +use crate::filters::jmespath_filter::custom_functions::create_eq_ignore_case_function; +use crate::filters::FilterError; + +lazy_static! { + static ref FILTER_RUNTIME: Runtime = { + let mut runtime = Runtime::new(); + runtime.register_builtin_functions(); + runtime.register_function("eq_ignore_case", Box::new(create_eq_ignore_case_function())); + runtime + }; +} + +/// Implementation of the [Filter] trait for complex checks, such as checking for +/// the presence of a key in an object or comparing the second value in an array +/// or check array length. +/// More examples: https://jmespath.org/examples.html; https://jmespath.org/tutorial.html +pub struct JmespathFilter { + filters: Vec>, +} + +impl Filter for JmespathFilter { + fn from_filters(filters: &[String]) -> Result { + let filters = filters + .iter() + .map(|f| { + FILTER_RUNTIME + .compile(f) + .map_err(|source| FilterError::CompileExpressionError { source }) + }) + .collect::>, FilterError>>(); + match filters { + Ok(filters) => Ok(Self { filters }), + Err(e) => Err(e), + } + } + + fn filter(&self, message: &Value) -> Result<(), FilterError> { + if self.filters.is_empty() { + return Ok(()); + } + + for filter in &self.filters { + match filter.search(message) { + Err(e) => return Err(FilterError::JmespathError { source: e }), + Ok(v) => { + if !v.as_boolean().unwrap() { + return Err(FilterError::FilterSkipMessage); + } + } + }; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::fs::File; + use std::io; + use std::io::{BufRead, BufReader}; + + use super::*; + + const SOURCE_PATH: &str = "tests/json/web_requests-100.json"; + + fn read_json_file(file_path: &str) -> io::Result> { + let file = File::open(file_path)?; + let reader = BufReader::new(file); + let lines: Vec<_> = reader.lines().take(30000).collect::>()?; + + let values: Vec = lines + .iter() + .map(|line| serde_json::from_str::(&line).unwrap()) + .collect(); + + Ok(values) + } + + fn run_filter(filter: &JmespathFilter, values: &Vec) -> (i32, i32) { + let mut passed_messages = 0; + let mut filtered_messages = 0; + + for v in values.into_iter() { + match filter.filter(&v) { + Ok(_) => passed_messages += 1, + Err(FilterError::FilterSkipMessage) => filtered_messages += 1, + Err(e) => panic!("{}", e), + } + } + + return (passed_messages, filtered_messages); + } + #[test] + fn equal() { + let values = read_json_file(SOURCE_PATH).unwrap(); + let filter = match JmespathFilter::from_filters(&vec![ + "session_id=='a8a3d0e3-7b4e-4f17-b264-76cb792bdb96'".to_string(), + ]) { + Ok(f) => f, + Err(e) => panic!("{}", e), + }; + + let (passed_messages, filtered_messages) = run_filter(&filter, &values); + + assert_eq!(1, passed_messages); + assert_eq!(99, filtered_messages); + } + #[test] + fn eq_ignore_case() { + let values = read_json_file(SOURCE_PATH).unwrap(); + let filter = match JmespathFilter::from_filters(&vec![ + "eq_ignore_case(method, 'get')".to_string() + ]) { + Ok(f) => f, + Err(e) => panic!("{}", e), + }; + + let (passed_messages, filtered_messages) = run_filter(&filter, &values); + + assert_eq!(17, passed_messages); + assert_eq!(83, filtered_messages); + } + + #[test] + fn or_condition() { + let values = read_json_file(SOURCE_PATH).unwrap(); + let filter = match JmespathFilter::from_filters(&vec![ + "(status == `404` || method == 'GET')".to_string(), + ]) { + Ok(f) => f, + Err(e) => panic!("{}", e), + }; + + let (passed_messages, filtered_messages) = run_filter(&filter, &values); + + assert_eq!(25, passed_messages); + assert_eq!(75, filtered_messages); + } + + #[test] + fn complex_condition() { + let buff = r#"{"name": "John Doe", "age": 30, "status": "1", "department": "Engineering"} + {"name": "Jane Smith", "age": 25, "status": "1", "department": "Marketing"} + {"name": "Emily Johnson", "age": 35, "department": "Sales"} + {"name": "Michael Brown", "age": 40, "status": "3", "department": "Engineering"} + {"name": "Sarah Davis", "age": 28, "department": "Marketing"} + {"name": "David Wilson", "age": 22, "department": "Sales"} + {"name": "Laura Martinez", "age": 33, "status": "2", "department": "Engineering"} + {"name": "James Anderson", "age": 45, "department": "Marketing"} + {"name": "Linda Thomas", "age": 50, "department": "Sales"} + {"name": "Robert Jackson", "age": 37, "department": "Engineering"}"#; + + let objects = buff.split("\n").map(|s| s.trim()).collect::>(); + let values: Vec = objects + .iter() + .map(|line| serde_json::from_str::(&line).unwrap()) + .collect(); + let filter = match JmespathFilter::from_filters(&vec![ + "!contains(keys(@), 'status') || (status == '1' && age >= `26`)".to_string(), + ]) { + Ok(f) => f, + Err(e) => panic!("{}", e), + }; + + let (passed_messages, filtered_messages) = run_filter(&filter, &values); + + assert_eq!(7, passed_messages); + assert_eq!(3, filtered_messages); + } +} diff --git a/src/filters/jmespath_filter/mod.rs b/src/filters/jmespath_filter/mod.rs new file mode 100644 index 0000000..c79dafe --- /dev/null +++ b/src/filters/jmespath_filter/mod.rs @@ -0,0 +1,2 @@ +mod custom_functions; +pub(super) mod filter; diff --git a/src/filters/mod.rs b/src/filters/mod.rs new file mode 100644 index 0000000..f1ea6b6 --- /dev/null +++ b/src/filters/mod.rs @@ -0,0 +1,11 @@ +pub use error::FilterError; +pub use filter::Filter; +pub use filter_factory::{FilterEngine, FilterFactory}; +pub(crate) use jmespath_filter::filter::JmespathFilter; +pub(crate) use naive_filter::filter::NaiveFilter; + +mod error; +mod filter; +mod filter_factory; +mod jmespath_filter; +mod naive_filter; diff --git a/src/filters/naive_filter/error.rs b/src/filters/naive_filter/error.rs new file mode 100644 index 0000000..5de1c23 --- /dev/null +++ b/src/filters/naive_filter/error.rs @@ -0,0 +1,18 @@ +#[derive(thiserror::Error, Debug)] +pub enum NaiveFilterError { + /// Error from [`serde_json`] + #[error("JSON serialization failed: {source}")] + SerdeJson { + /// Wrapped [`serde_json::Error`] + #[from] + source: serde_json::Error, + }, + + /// Error occurs when trying to execute a filter + #[error("NaiveFilter execution error: {reason}")] + RuntimeError { reason: String }, + + /// Error occurs when trying to prepare filters for execution + #[error("NaiveFilter prepare error: {reason}")] + PrepareError { reason: String }, +} diff --git a/src/filters/naive_filter/filter.rs b/src/filters/naive_filter/filter.rs new file mode 100644 index 0000000..d9136d8 --- /dev/null +++ b/src/filters/naive_filter/filter.rs @@ -0,0 +1,212 @@ +use regex::Regex; +use serde_json::Value; + +use crate::filters::filter::Filter; +use crate::filters::naive_filter::operand::NaiveFilterOperand; +use crate::filters::naive_filter::operator::{get_operator, OperatorRef}; +use crate::filters::FilterError; + +pub struct NaiveFilterExpression { + left: NaiveFilterOperand, + op: OperatorRef, + right: NaiveFilterOperand, +} + +/// Implementation of the [Filter] feature for simple comparison checks. +/// If a path was provided, it must always be present in the object. +/// Possible operations: >=, <=, ==, !=, ~=, >, < +/// ~= - case-insensitive comparison, for example: path.to.attr ~= 'VaLuE' +pub(crate) struct NaiveFilter { + filters: Vec, +} + +impl Filter for NaiveFilter { + fn from_filters(filters: &[String]) -> Result { + let mut expressions: Vec = Vec::new(); + let re = Regex::new(r"(?.*)(?>=|<=|==|!=|~=|>|<)(?.*)").unwrap(); + for filter in filters.iter() { + let (_, [left, op, right]) = re.captures(filter.trim()).unwrap().extract(); + expressions.push(NaiveFilterExpression { + left: NaiveFilterOperand::from_str(left)?, + op: get_operator(op)?, + right: NaiveFilterOperand::from_str(right)?, + }); + } + + Ok(NaiveFilter { + filters: expressions, + }) + } + + fn filter(&self, message: &Value) -> Result<(), FilterError> { + for filter in self.filters.iter() { + if !filter.op.execute( + filter.left.get_value(message), + filter.right.get_value(message), + )? { + return Err(FilterError::FilterSkipMessage); + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::fs::File; + use std::io; + use std::io::{BufRead, BufReader}; + + use super::*; + + const SOURCE_PATH: &str = "tests/json/web_requests-100.json"; + + fn read_json_file(file_path: &str) -> io::Result> { + let file = File::open(file_path)?; + let reader = BufReader::new(file); + let lines: Vec<_> = reader.lines().take(30000).collect::>()?; + + let values: Vec = lines + .iter() + .map(|line| serde_json::from_str::(&line).unwrap()) + .collect(); + + Ok(values) + } + + fn run_filter(filter: &NaiveFilter, values: &Vec) -> (i32, i32) { + let mut passed_messages = 0; + let mut filtered_messages = 0; + + for v in values.into_iter() { + match filter.filter(&v) { + Ok(_) => passed_messages += 1, + Err(FilterError::FilterSkipMessage) => filtered_messages += 1, + Err(e) => panic!("{}", e), + } + } + + return (passed_messages, filtered_messages); + } + #[test] + fn greater_than_or_equal() { + let values = read_json_file(SOURCE_PATH).unwrap(); + let filter = match NaiveFilter::from_filters(&vec![ + "status>=`201`".to_string(), + "method=='GET'".to_string(), + ]) { + Ok(f) => f, + Err(e) => panic!("{}", e), + }; + let (passed_messages, filtered_messages) = run_filter(&filter, &values); + + assert_eq!(14, passed_messages); + assert_eq!(86, filtered_messages); + } + + #[test] + fn less_than_or_equal() { + let values = read_json_file(SOURCE_PATH).unwrap(); + let filter = match NaiveFilter::from_filters(&vec![ + "status<=`403`".to_string(), + "method=='POST'".to_string(), + ]) { + Ok(f) => f, + Err(e) => panic!("{}", e), + }; + let (passed_messages, filtered_messages) = run_filter(&filter, &values); + + assert_eq!(12, passed_messages); + assert_eq!(88, filtered_messages); + } + + #[test] + fn equal() { + let values = read_json_file(SOURCE_PATH).unwrap(); + let filter = match NaiveFilter::from_filters(&vec![ + "session_id=='a8a3d0e3-7b4e-4f17-b264-76cb792bdb96'".to_string(), + ]) { + Ok(f) => f, + Err(e) => panic!("{}", e), + }; + + let (passed_messages, filtered_messages) = run_filter(&filter, &values); + + assert_eq!(1, passed_messages); + assert_eq!(99, filtered_messages); + } + #[test] + fn not_equal() { + let values = read_json_file(SOURCE_PATH).unwrap(); + let filter = match NaiveFilter::from_filters(&vec!["method!='POST'".to_string()]) { + Ok(f) => f, + Err(e) => panic!("{}", e), + }; + + let (passed_messages, filtered_messages) = run_filter(&filter, &values); + + assert_eq!(81, passed_messages); + assert_eq!(19, filtered_messages); + } + #[test] + fn eq_ignore_case() { + let values = read_json_file(SOURCE_PATH).unwrap(); + let filter = match NaiveFilter::from_filters(&vec!["method~='get'".to_string()]) { + Ok(f) => f, + Err(e) => panic!("{}", e), + }; + + let (passed_messages, filtered_messages) = run_filter(&filter, &values); + + assert_eq!(17, passed_messages); + assert_eq!(83, filtered_messages); + } + + #[test] + fn invalid_filters() { + assert!( + NaiveFilter::from_filters(&vec!["method~='get]".to_string()]).is_err(), + "The filter should not have been created" + ); + assert!( + NaiveFilter::from_filters(&vec!["method~='get']".to_string()]).is_err(), + "The filter should not have been created" + ); + assert!( + NaiveFilter::from_filters(&vec!["status~=`404".to_string()]).is_err(), + "The filter should not have been created" + ); + assert!( + NaiveFilter::from_filters(&vec!["status~=`404,123`".to_string()]).is_err(), + "The filter should not have been created" + ); + assert!( + NaiveFilter::from_filters(&vec!["status~=`abc`".to_string()]).is_err(), + "The filter should not have been created" + ); + assert!( + NaiveFilter::from_filters(&vec!["status~=`abc`".to_string()]).is_err(), + "The filter should not have been created" + ); + } + + #[test] + fn valid_filters() { + assert!( + NaiveFilter::from_filters(&vec!["method=='get'".to_string()]).is_ok(), + "The filter should have been created" + ); + assert!( + NaiveFilter::from_filters(&vec!["status==`404`".to_string()]).is_ok(), + "The filter should have been created" + ); + assert!( + NaiveFilter::from_filters(&vec!["status==internal.status".to_string()]).is_ok(), + "The filter should have been created" + ); + assert!( + NaiveFilter::from_filters(&vec!["internal.value!=`3.1415962`".to_string()]).is_ok(), + "The filter should have been created" + ); + } +} diff --git a/src/filters/naive_filter/mod.rs b/src/filters/naive_filter/mod.rs new file mode 100644 index 0000000..6554787 --- /dev/null +++ b/src/filters/naive_filter/mod.rs @@ -0,0 +1,4 @@ +pub(super) mod error; +pub(super) mod filter; +pub(super) mod operand; +pub(super) mod operator; diff --git a/src/filters/naive_filter/operand.rs b/src/filters/naive_filter/operand.rs new file mode 100644 index 0000000..61bec4d --- /dev/null +++ b/src/filters/naive_filter/operand.rs @@ -0,0 +1,74 @@ +use serde_json::{json, Value}; + +use crate::filters::naive_filter::error::NaiveFilterError; + +/// Container to store the path to the value or the value itself for later comparison +pub(super) struct NaiveFilterOperand { + pub value: Option, + pub path: Option>, +} + +impl NaiveFilterOperand { + fn new(value: Option, path: Option) -> Result { + if value.is_none() && path.is_none() { + return Err(NaiveFilterError::PrepareError { + reason: "Cannot create expression without path or value".to_string(), + }); + }; + + if value.is_some() { + return Ok(Self { value, path: None }); + } + + let path: Vec = path.unwrap().split('.').map(str::to_string).collect(); + Ok(Self { + value, + path: Some(path), + }) + } + + pub(crate) fn from_str(operand_str: &str) -> Result { + let operand_str = operand_str.trim(); + + match operand_str.chars().next() { + // number + Some('`') => { + if !operand_str.ends_with('`') { + return Err(NaiveFilterError::PrepareError { + reason: "To filter by number, the number must begin and end with `" + .to_string(), + }); + } + NaiveFilterOperand::new(serde_json::from_str(operand_str.trim_matches('`'))?, None) + } + // string + Some('\'') => { + if !operand_str.ends_with('\'') { + return Err(NaiveFilterError::PrepareError { + reason: "To filter by string, the string must begin and end with '" + .to_string(), + }); + } + NaiveFilterOperand::new(Some(json!(operand_str.trim_matches('\''))), None) + } + // path to attribute via dot + _ => NaiveFilterOperand::new(None, Some(operand_str.to_string())), + } + } + fn is_path(&self) -> bool { + self.path.is_some() + } + + pub(crate) fn get_value<'a>(&'a self, message: &'a Value) -> &Value { + return if self.is_path() { + let mut path_iter = self.path.as_ref().unwrap().iter(); + let mut cursor: &Value = &message[path_iter.next().unwrap()]; + for p in path_iter { + cursor = &cursor[p]; + } + return cursor; + } else { + self.value.as_ref().unwrap() + }; + } +} diff --git a/src/filters/naive_filter/operator.rs b/src/filters/naive_filter/operator.rs new file mode 100644 index 0000000..b5e183f --- /dev/null +++ b/src/filters/naive_filter/operator.rs @@ -0,0 +1,160 @@ +use std::sync::Arc; + +use serde_json::Value; + +use crate::filters::naive_filter::error::NaiveFilterError; + +struct GteOperator {} +struct LteOperator {} +struct EqOperator {} +struct NeqOperator {} +struct IeqOperator {} +struct GtOperator {} +struct LtOperator {} + +pub(crate) trait Operator: Send + Sync + 'static { + fn execute(&self, left: &Value, right: &Value) -> Result; +} + +impl Operator for GteOperator { + fn execute(&self, left: &Value, right: &Value) -> Result { + match left { + Value::Number(n) => { + if let Some(integer) = n.as_i64() { + Ok(integer >= right.as_i64().unwrap()) + } else { + Ok(n.as_f64().unwrap() >= right.as_f64().unwrap()) + } + }, + _ => Err( + NaiveFilterError::RuntimeError { + reason: format!("The >= operator can only be used for numbers (for example, `2` or `3.1415`, along with quotes). Passed: {:?}, {:?}", left, right) + } + ) + } + } +} + +impl Operator for LteOperator { + fn execute(&self, left: &Value, right: &Value) -> Result { + match left { + Value::Number(n) => { + if let Some(integer) = n.as_i64() { + Ok(integer <= right.as_i64().unwrap()) + } else { + Ok(n.as_f64().unwrap() <= right.as_f64().unwrap()) + } + }, + _ => Err( + NaiveFilterError::RuntimeError { + reason: format!("The <= operator can only be used for numbers (for example, `2` or `3.1415`, along with quotes). Passed: {:?}, {:?}", left, right) + } + ) + } + } +} +impl Operator for EqOperator { + fn execute(&self, left: &Value, right: &Value) -> Result { + match left { + Value::Number(n) => { + if let Some(integer) = n.as_i64() { + Ok(integer == right.as_i64().unwrap()) + } else { + Ok(n.as_f64().unwrap() == right.as_f64().unwrap()) + } + }, + Value::String(s) => { + return Ok(s.as_str() == right.as_str().unwrap()) + }, + Value::Bool(b) => Ok(*b == right.as_bool().unwrap()), + _ => Err( + NaiveFilterError::RuntimeError { + reason: format!("The == operator can only be used for numbers, strings or bools. Passed: {:?}, {:?}", left, right) + } + ) + } + } +} +impl Operator for NeqOperator { + fn execute(&self, left: &Value, right: &Value) -> Result { + match left { + Value::Number(n) => { + if let Some(integer) = n.as_i64() { + Ok(integer != right.as_i64().unwrap()) + } else { + Ok(n.as_f64().unwrap() != right.as_f64().unwrap()) + } + }, + Value::String(s) => { + Ok(s.as_str() != right.as_str().unwrap()) + }, + Value::Bool(b) => Ok(*b != right.as_bool().unwrap()), + _ => Err( + NaiveFilterError::RuntimeError { + reason: format!("The != operator can only be used for numbers, strings or bools. Passed: {:?}, {:?}", left, right) + } + ) + } + } +} +impl Operator for IeqOperator { + fn execute(&self, left: &Value, right: &Value) -> Result { + Ok(left + .as_str() + .unwrap() + .eq_ignore_ascii_case(right.as_str().unwrap())) + } +} +impl Operator for GtOperator { + fn execute(&self, left: &Value, right: &Value) -> Result { + match left { + Value::Number(n) => { + if let Some(integer) = n.as_i64() { + Ok(integer > right.as_i64().unwrap()) + } else { + Ok(n.as_f64().unwrap() > right.as_f64().unwrap()) + } + }, + _ => Err( + NaiveFilterError::RuntimeError { + reason: format!("The > operator can only be used for numbers (for example, `2` or `3.1415`, along with quotes). Passed: {:?}, {:?}", left, right) + } + ) + } + } +} +impl Operator for LtOperator { + fn execute(&self, left: &Value, right: &Value) -> Result { + match left { + Value::Number(n) => { + if let Some(integer) = n.as_i64() { + Ok(integer < right.as_i64().unwrap()) + } else { + Ok(n.as_f64().unwrap() < right.as_f64().unwrap()) + } + }, + _ => Err( + NaiveFilterError::RuntimeError { + reason: format!("The < operator can only be used for numbers (for example, `2` or `3.1415`, along with quotes). Passed: {:?}, {:?}", left, right) + } + ) + } + } +} + +pub(crate) type OperatorRef = Arc; + +pub(crate) fn get_operator(operator_str: &str) -> Result { + match operator_str { + ">=" => Ok(Arc::new(GteOperator {})), + "<=" => Ok(Arc::new(LteOperator {})), + "==" => Ok(Arc::new(EqOperator {})), + "!=" => Ok(Arc::new(NeqOperator {})), + "~=" => Ok(Arc::new(IeqOperator {})), + ">" => Ok(Arc::new(GtOperator {})), + "<" => Ok(Arc::new(LtOperator {})), + _ => Err(NaiveFilterError::RuntimeError { + reason: format!("There is no operand {}", operator_str), + }), + } +} diff --git a/src/lib.rs b/src/lib.rs index f1ab51a..fbf2d83 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,20 +9,23 @@ #[macro_use] extern crate lazy_static; - +#[cfg(test)] +extern crate serde_json; #[macro_use] extern crate strum_macros; -#[cfg(test)] -extern crate serde_json; +use std::ops::Add; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use std::{collections::HashMap, path::PathBuf}; -use coercions::CoercionTree; use deltalake_core::operations::transaction::TableReference; use deltalake_core::protocol::DeltaOperation; use deltalake_core::protocol::OutputMode; use deltalake_core::{DeltaTable, DeltaTableError}; use futures::stream::StreamExt; use log::{debug, error, info, warn}; +use rdkafka::message::BorrowedMessage; use rdkafka::{ config::ClientConfig, consumer::{Consumer, ConsumerContext, Rebalance, StreamConsumer}, @@ -31,19 +34,30 @@ use rdkafka::{ ClientContext, Message, Offset, TopicPartitionList, }; use serde_json::Value; -use serialization::{MessageDeserializer, MessageDeserializerFactory}; -use std::sync::Arc; -use std::time::{Duration, Instant}; -use std::{collections::HashMap, path::PathBuf}; use tokio::sync::RwLock; use tokio_util::sync::CancellationToken; use url::Url; +use coercions::CoercionTree; +use delta_helpers::*; +use serialization::{MessageDeserializer, MessageDeserializerFactory}; + +pub use crate::filters::{Filter, FilterEngine, FilterError, FilterFactory}; +use crate::offsets::WriteOffsetsError; +use crate::value_buffers::{ConsumedBuffers, ValueBuffers}; +use crate::{ + dead_letters::*, + metrics::*, + transforms::*, + writer::{DataWriter, DataWriterError}, +}; + mod coercions; /// Doc pub mod cursor; mod dead_letters; mod delta_helpers; +mod filters; mod metrics; mod offsets; mod serialization; @@ -52,18 +66,6 @@ mod value_buffers; /// Doc pub mod writer; -use crate::offsets::WriteOffsetsError; -use crate::value_buffers::{ConsumedBuffers, ValueBuffers}; -use crate::{ - dead_letters::*, - metrics::*, - transforms::*, - writer::{DataWriter, DataWriterError}, -}; -use delta_helpers::*; -use rdkafka::message::BorrowedMessage; -use std::ops::Add; - /// Type alias for Kafka partition pub type DataTypePartition = i32; /// Type alias for Kafka message offset @@ -205,6 +207,21 @@ pub enum IngestError { /// The underlying error. source: anyhow::Error, }, + + /// Errors returned by the filter + #[error("FilterError: {source}")] + Filter { + /// Wrapped [`FilterError`] + source: Box, + }, +} + +impl From for IngestError { + fn from(error: FilterError) -> Self { + IngestError::Filter { + source: Box::new(error), + } + } } /// Formats for message parsing @@ -280,6 +297,10 @@ pub struct IngestOptions { pub min_bytes_per_file: usize, /// A list of transforms to apply to the message before writing to delta lake. pub transforms: HashMap, + /// A list for filtering by message fields + pub filters: Vec, + /// Filter engine used + pub filter_engine: FilterEngine, /// An optional dead letter table to write messages that fail deserialization, transformation or schema validation. pub dlq_table_uri: Option, /// Transforms to apply to dead letters when writing to a delta table. @@ -310,6 +331,8 @@ impl Default for IngestOptions { max_messages_per_batch: 5000, min_bytes_per_file: 134217728, transforms: HashMap::new(), + filters: Vec::new(), + filter_engine: FilterEngine::Naive, dlq_table_uri: None, dlq_transforms: HashMap::new(), additional_kafka_settings: None, @@ -443,6 +466,13 @@ pub async fn start_ingest( debug!("Skipping message with partition {}, offset {} on topic {} because it was already processed", partition, offset, topic); continue; } + IngestError::Filter { source } => match *source { + FilterError::FilterSkipMessage => { + ingest_metrics.message_filtered(); + debug!("Skip message by filter"); + } + _ => return Err(IngestError::Filter { source }), + }, _ => return Err(e), } } @@ -734,6 +764,7 @@ struct IngestProcessor { topic: String, consumer: Arc>, transformer: Transformer, + filter: Box, coercion_tree: CoercionTree, table: DeltaTable, delta_writer: DataWriter, @@ -758,6 +789,7 @@ impl IngestProcessor { let dlq = dead_letter_queue_from_options(&opts).await?; let transformer = Transformer::from_transforms(&opts.transforms)?; let table = delta_helpers::load_table(table_uri, HashMap::new()).await?; + let filter = FilterFactory::try_build(&opts.filter_engine, &opts.filters)?; let coercion_tree = coercions::create_coercion_tree(table.schema().unwrap()); let delta_writer = DataWriter::for_table(&table, HashMap::new())?; let deserializer = @@ -770,6 +802,7 @@ impl IngestProcessor { topic, consumer, transformer, + filter, coercion_tree, table, delta_writer, @@ -820,6 +853,8 @@ impl IngestProcessor { // Deserialize match self.deserialize_message(&message).await { Ok(mut value) => { + self.filter.filter(&value)?; + self.ingest_metrics.message_deserialized(); // Transform match self.transformer.transform(&mut value, Some(&message)) { diff --git a/src/main.rs b/src/main.rs index 3420f28..be3fb9b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -30,18 +30,20 @@ #![deny(deprecated)] #![deny(missing_docs)] -use chrono::Local; -use clap::{Arg, ArgAction, ArgGroup, ArgMatches, Command}; -use kafka_delta_ingest::{ - start_ingest, AutoOffsetReset, DataTypeOffset, DataTypePartition, IngestOptions, MessageFormat, - SchemaSource, -}; -use log::{error, info, LevelFilter}; use std::collections::HashMap; use std::io::prelude::*; use std::path::PathBuf; use std::str::FromStr; +use chrono::Local; +use clap::{Arg, ArgAction, ArgGroup, ArgMatches, Command}; +use log::{error, info, LevelFilter}; + +use kafka_delta_ingest::{ + start_ingest, AutoOffsetReset, DataTypeOffset, DataTypePartition, FilterEngine, FilterError, + IngestOptions, MessageFormat, SchemaSource, +}; + #[tokio::main(flavor = "current_thread")] async fn main() -> anyhow::Result<()> { #[cfg(feature = "s3")] @@ -119,6 +121,13 @@ async fn main() -> anyhow::Result<()> { .map(|list| list.map(|t| parse_transform(t).unwrap()).collect()) .unwrap_or_else(HashMap::new); + let filters: Vec = ingest_matches + .get_many::("filter") + .map(|list| list.cloned().collect()) + .unwrap_or_else(Vec::new); + + let filter_engine: FilterEngine = convert_matches_to_filter_engine(ingest_matches)?; + let dlq_table_location = ingest_matches .get_one::("dlq_table_location") .map(|s| s.to_string()); @@ -156,6 +165,8 @@ async fn main() -> anyhow::Result<()> { max_messages_per_batch: *max_messages_per_batch, min_bytes_per_file: *min_bytes_per_file, transforms, + filters, + filter_engine, dlq_table_uri: dlq_table_location, dlq_transforms, write_checkpoints, @@ -402,6 +413,39 @@ the following well-known Kafka metadata properties: * kafka.topic * kafka.timestamp "#)) + .arg(Arg::new("filter") + .short('f') + .long("filter") + .action(ArgAction::Append) + .help( + r#"A list of filters that will be applied to each message before transforms. +Filters are separated by semicolons. There are two types of filter. +The first, naive filter, which is used by default, supports simple operations and a path flowing through points. +List of operations: ==, !=, >, <, >=, <=, ~=. The last one is case-insensitive string comparison. +For example: +-f "path.to.num.value >= `5`;string_value_key~='buzz'" + +The second jmespath-based filter is used for complex conditions, such as checking inside an array +or checking for the presence of a key. Due to its more complex functionality, it works slower +than the naive one, but is still quite fast. See: https://jmespath.org/tutorial.html +For example: +-f "!contains(keys(@), 'status') || (status == 'status' && factor >= `26`)" + +Strings must be enclosed in single quotes "'", numbers must be enclosed in backticks "`" +"#) + .env("FILTERS") + .num_args(0..) + .value_delimiter(';')) + .arg( + Arg::new("filter_engine") + .long("filter_engine") + .env("FILTER_ENGINE") + .value_parser(["naive", "jmespath"]) + .action(ArgAction::Set) + .default_value("naive") + .help("Naive for simple comparisons and quick work, jmespath for complex comparisons") + .required(false) + ) .arg(Arg::new("dlq_table_location") .long("dlq_table_location") .env("DLQ_TABLE_LOCATION") @@ -480,9 +524,26 @@ fn convert_matches_to_message_format( .map(MessageFormat::Json); } +fn convert_matches_to_filter_engine( + ingest_matches: &ArgMatches, +) -> Result> { + return match ingest_matches + .get_one::("filter_engine") + .unwrap() + .as_str() + { + "naive" => Ok(FilterEngine::Naive), + "jmespath" => Ok(FilterEngine::Jmespath), + f => Err(Box::new(FilterError::NotFound { + name: f.to_string(), + })), + }; +} + #[cfg(test)] mod test { use clap::ArgMatches; + use kafka_delta_ingest::{MessageFormat, SchemaSource}; use crate::{ diff --git a/src/metrics.rs b/src/metrics.rs index 3bbac91..70c467a 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -1,8 +1,9 @@ -use dipstick::*; -use log::error; use std::convert::TryInto; use std::time::Instant; +use dipstick::*; +use log::error; + /// The environment variable used to specify how many metrics should be written to the metrics queue before flushing to statsd. const METRICS_INPUT_QUEUE_SIZE_VAR_NAME: &str = "KDI_METRICS_INPUT_QUEUE_SIZE"; /// The environment variable used to specify a prefix for metrics. @@ -58,6 +59,11 @@ impl IngestMetrics { self.record_one(StatType::MessageTransformFailed); } + /// increments a counter for message filtered + pub fn message_filtered(&self) { + self.record_one(StatType::MessageFiltered); + } + /// increments a counter for record batch started pub fn batch_started(&self) { self.record_one(StatType::RecordBatchStarted); @@ -236,6 +242,9 @@ enum StatType { /// Counter for a message that failed transformation. #[strum(serialize = "messages.transform.failed")] MessageTransformFailed, + /// Counter for a message that skipped. + #[strum(serialize = "messages.filter.filtered")] + MessageFiltered, /// Counter for when a record batch is started. #[strum(serialize = "recordbatch.started")] RecordBatchStarted,