diff --git a/Cargo.lock b/Cargo.lock
index 5dbd3d26..39357a49 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2737,6 +2737,7 @@ dependencies = [
  "log",
  "optd-datafusion-bridge",
  "optd-datafusion-repr",
+ "optd-gungnir",
  "regex",
  "shlex",
  "tokio",
diff --git a/optd-datafusion-repr/src/cost/adaptive_cost.rs b/optd-datafusion-repr/src/cost/adaptive_cost.rs
index a0b07309..35625782 100644
--- a/optd-datafusion-repr/src/cost/adaptive_cost.rs
+++ b/optd-datafusion-repr/src/cost/adaptive_cost.rs
@@ -51,15 +51,13 @@ impl CostModel<OptRelNodeTyp> for AdaptiveCostModel {
     ) -> Cost {
         if let OptRelNodeTyp::PhysicalScan = node {
             let guard = self.runtime_row_cnt.lock().unwrap();
-            if let Some((runtime_row_cnt, iter)) = guard.history.get(&context.unwrap().group_id) {
+            if let Some((runtime_row_cnt, iter)) =
+                guard.history.get(&context.as_ref().unwrap().group_id)
+            {
                 if *iter + self.decay >= guard.iter_cnt {
                     let runtime_row_cnt = (*runtime_row_cnt).max(1) as f64;
                     return OptCostModel::cost(runtime_row_cnt, 0.0, runtime_row_cnt);
-                } else {
-                    return OptCostModel::cost(1.0, 0.0, 1.0);
                 }
-            } else {
-                return OptCostModel::cost(1.0, 0.0, 1.0);
             }
         }
         let (mut row_cnt, compute_cost, io_cost) = OptCostModel::cost_tuple(
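
The two removed `else` arms are a behavior change, not just cleanup: a `PhysicalScan` whose group has no runtime row count, or only a stale one, now falls through to the static `OptCostModel` estimate below instead of returning a constant cost of 1.0. The `context.as_ref().unwrap()` fix borrows the `Option<RelNodeContext>` instead of consuming it. A minimal sketch of the resulting control flow, with plain `usize` ids standing in for the real group-id and history types:

    use std::collections::HashMap;

    /// `history` maps group_id -> (runtime_row_cnt, iteration_recorded).
    fn runtime_scan_rows(
        history: &HashMap<usize, (usize, usize)>,
        group_id: usize,
        iter_cnt: usize,
        decay: usize,
    ) -> Option<f64> {
        let (row_cnt, iter) = history.get(&group_id)?;
        // Only trust a runtime row count recorded within the last `decay`
        // iterations; `None` means "fall through to the static cost model".
        (iter + decay >= iter_cnt).then(|| (*row_cnt).max(1) as f64)
    }
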
diff --git a/optd-datafusion-repr/src/cost/base_cost.rs b/optd-datafusion-repr/src/cost/base_cost.rs
index f08c7b6f..3f5a6de8 100644
--- a/optd-datafusion-repr/src/cost/base_cost.rs
+++ b/optd-datafusion-repr/src/cost/base_cost.rs
@@ -6,12 +6,20 @@ use crate::{
     plan_nodes::{OptRelNodeRef, OptRelNodeTyp},
     properties::column_ref::ColumnRef,
 };
+use arrow_schema::{ArrowError, DataType};
+use datafusion::arrow::array::{
+    Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, Int16Array,
+    Int32Array, Int8Array, RecordBatch, RecordBatchIterator, RecordBatchReader, UInt16Array,
+    UInt32Array, UInt8Array,
+};
 use itertools::Itertools;
 use optd_core::{
     cascades::{CascadesOptimizer, RelNodeContext},
     cost::{Cost, CostModel},
     rel_node::{RelNode, RelNodeTyp, Value},
 };
+use optd_gungnir::stats::hyperloglog::{self, HyperLogLog};
+use optd_gungnir::stats::tdigest::{self, TDigest};
 
 fn compute_plan_node_cost<T: RelNodeTyp, C: CostModel<T>>(
     model: &C,
@@ -34,9 +42,207 @@ pub struct OptCostModel {
     per_table_stats_map: BaseTableStats,
 }
 
+struct MockMostCommonValues {
+    mcvs: HashMap<Value, f64>,
+}
+
+impl MockMostCommonValues {
+    pub fn empty() -> Self {
+        MockMostCommonValues {
+            mcvs: HashMap::new(),
+        }
+    }
+}
+
+impl MostCommonValues for MockMostCommonValues {
+    fn freq(&self, value: &Value) -> Option<f64> {
+        self.mcvs.get(value).copied()
+    }
+
+    fn total_freq(&self) -> f64 {
+        self.mcvs.values().sum()
+    }
+
+    fn freq_over_pred(&self, pred: Box<dyn Fn(&Value) -> bool>) -> f64 {
+        self.mcvs
+            .iter()
+            .filter(|(val, _)| pred(val))
+            .map(|(_, freq)| freq)
+            .sum()
+    }
+
+    fn cnt(&self) -> usize {
+        self.mcvs.len()
+    }
+}
+
 pub struct PerTableStats {
     row_cnt: usize,
-    per_column_stats_vec: Vec<PerColumnStats>,
+    per_column_stats_vec: Vec<Option<PerColumnStats>>,
+}
+
+impl PerTableStats {
+    pub fn from_record_batches<I: IntoIterator<Item = Result<RecordBatch, ArrowError>>>(
+        batch_iter: RecordBatchIterator<I>,
+    ) -> anyhow::Result<Self> {
+        let schema = batch_iter.schema();
+        let col_types = schema
+            .fields()
+            .iter()
+            .map(|f| f.data_type().clone())
+            .collect_vec();
+        let col_cnt = col_types.len();
+
+        let mut row_cnt = 0;
+        let mut mcvs = col_types
+            .iter()
+            .map(|col_type| {
+                if Self::is_type_supported(col_type) {
+                    Some(MockMostCommonValues::empty())
+                } else {
+                    None
+                }
+            })
+            .collect_vec();
+        let mut distr = col_types
+            .iter()
+            .map(|col_type| {
+                if Self::is_type_supported(col_type) {
+                    Some(TDigest::new(tdigest::DEFAULT_COMPRESSION))
+                } else {
+                    None
+                }
+            })
+            .collect_vec();
+        let mut hlls = vec![HyperLogLog::new(hyperloglog::DEFAULT_PRECISION); col_cnt];
+        let mut null_cnt = vec![0; col_cnt];
+
+        for batch in batch_iter {
+            let batch = batch?;
+            row_cnt += batch.num_rows();
+
+            // Enumerate the columns.
+            for (i, col) in batch.columns().iter().enumerate() {
+                let col_type = &col_types[i];
+                if Self::is_type_supported(col_type) {
+                    // Update null cnt.
+                    null_cnt[i] += col.null_count();
+
+                    Self::generate_stats_for_column(col, col_type, &mut distr[i], &mut hlls[i]);
+                }
+            }
+        }
+
+        // Assemble the per-column stats.
+        let mut per_column_stats_vec = Vec::with_capacity(col_cnt);
+        for i in 0..col_cnt {
+            per_column_stats_vec.push(if Self::is_type_supported(&col_types[i]) {
+                Some(PerColumnStats {
+                    mcvs: Box::new(mcvs[i].take().unwrap()) as Box<dyn MostCommonValues>,
+                    ndistinct: hlls[i].n_distinct(),
+                    null_frac: null_cnt[i] as f64 / row_cnt as f64,
+                    distr: Box::new(distr[i].take().unwrap()) as Box<dyn Distribution>,
+                })
+            } else {
+                None
+            });
+        }
+        Ok(Self {
+            row_cnt,
+            per_column_stats_vec,
+        })
+    }
+
+    fn is_type_supported(data_type: &DataType) -> bool {
+        matches!(
+            data_type,
+            DataType::Boolean
+                | DataType::Int8
+                | DataType::Int16
+                | DataType::Int32
+                | DataType::UInt8
+                | DataType::UInt16
+                | DataType::UInt32
+                | DataType::Float32
+                | DataType::Float64
+        )
+    }
+
+    /// Generate statistics for a column.
+    fn generate_stats_for_column(
+        col: &Arc<dyn Array>,
+        col_type: &DataType,
+        distr: &mut Option<TDigest>,
+        hll: &mut HyperLogLog,
+    ) {
+        macro_rules! generate_stats_for_col {
+            ({ $col:expr, $distr:expr, $hll:expr, $array_type:path, $to_f64:ident }) => {{
+                let array = $col.as_any().downcast_ref::<$array_type>().unwrap();
+                // Filter out `None` values.
+                let values = array.iter().filter_map(|x| x).collect::<Vec<_>>();
+
+                // Update distribution.
+                *$distr = {
+                    let mut f64_values = values.iter().map(|x| $to_f64(*x)).collect::<Vec<_>>();
+                    Some($distr.take().unwrap().merge_values(&mut f64_values))
+                };
+
+                // Update hll.
+                $hll.aggregate(&values);
+            }};
+        }
+
+        /// Convert a value to f64 with no out-of-range or precision loss.
+        fn to_f64_safe<T: Into<f64>>(val: T) -> f64 {
+            val.into()
+        }
+
+        /// Convert i128 to f64 with possible precision loss.
+        ///
+        /// Note: optd represents a decimal with its significand as f64 (see `ConstantExpr::decimal`).
+        /// For instance, 0.04 of type `Decimal128(15, 2)` is just 4.0; the type information
+        /// is discarded. Therefore we must use the significand to generate the statistics.
+        fn i128_to_f64(val: i128) -> f64 {
+            val as f64
+        }
+
+        match col_type {
+            DataType::Boolean => {
+                generate_stats_for_col!({ col, distr, hll, BooleanArray, to_f64_safe })
+            }
+            DataType::Int8 => {
+                generate_stats_for_col!({ col, distr, hll, Int8Array, to_f64_safe })
+            }
+            DataType::Int16 => {
+                generate_stats_for_col!({ col, distr, hll, Int16Array, to_f64_safe })
+            }
+            DataType::Int32 => {
+                generate_stats_for_col!({ col, distr, hll, Int32Array, to_f64_safe })
+            }
+            DataType::UInt8 => {
+                generate_stats_for_col!({ col, distr, hll, UInt8Array, to_f64_safe })
+            }
+            DataType::UInt16 => {
+                generate_stats_for_col!({ col, distr, hll, UInt16Array, to_f64_safe })
+            }
+            DataType::UInt32 => {
+                generate_stats_for_col!({ col, distr, hll, UInt32Array, to_f64_safe })
+            }
+            DataType::Float32 => {
+                generate_stats_for_col!({ col, distr, hll, Float32Array, to_f64_safe })
+            }
+            DataType::Float64 => {
+                generate_stats_for_col!({ col, distr, hll, Float64Array, to_f64_safe })
+            }
+            DataType::Date32 => {
+                generate_stats_for_col!({ col, distr, hll, Date32Array, to_f64_safe })
+            }
+            DataType::Decimal128(_, _) => {
+                generate_stats_for_col!({ col, distr, hll, Decimal128Array, i128_to_f64 })
+            }
+            _ => unreachable!(),
+        }
+    }
 }
 
 pub struct PerColumnStats {
@@ -45,7 +251,7 @@ pub struct PerColumnStats {
     // ndistinct _does_ include the values in mcvs
     // ndistinct _does not_ include nulls
-    ndistinct: i32,
+    ndistinct: u64,
 
     // postgres uses null_frac instead of something like "num_nulls" so we'll follow suit
     // my guess for why they use null_frac is because we only ever use the fraction of nulls, not the #
@@ -445,7 +651,8 @@ impl OptCostModel {
         is_eq: bool,
     ) -> Option<f64> {
         if let Some(per_table_stats) = self.per_table_stats_map.get(table) {
-            if let Some(per_column_stats) = per_table_stats.per_column_stats_vec.get(col_idx) {
+            if let Some(Some(per_column_stats)) = per_table_stats.per_column_stats_vec.get(col_idx)
+            {
                 let eq_freq = if let Some(freq) = per_column_stats.mcvs.freq(value) {
                     freq
                 } else {
@@ -484,7 +691,8 @@ impl OptCostModel {
         is_col_eq_val: bool,
     ) -> Option<f64> {
         if let Some(per_table_stats) = self.per_table_stats_map.get(table) {
-            if let Some(per_column_stats) = per_table_stats.per_column_stats_vec.get(col_idx) {
+            if let Some(Some(per_column_stats)) = per_table_stats.per_column_stats_vec.get(col_idx)
+            {
                 // because distr does not include the values in MCVs, we need to compute the CDFs there as well
                 // because nulls return false in any comparison, they are never included when computing range selectivity
                 let distr_leq_freq = per_column_stats.distr.cdf(value);
@@ -555,7 +763,7 @@ impl OptCostModel {
 }
 
 impl PerTableStats {
-    pub fn new(row_cnt: usize, per_column_stats_vec: Vec<PerColumnStats>) -> Self {
+    pub fn new(row_cnt: usize, per_column_stats_vec: Vec<Option<PerColumnStats>>) -> Self {
         Self {
             row_cnt,
             per_column_stats_vec,
@@ -566,7 +774,7 @@ impl PerTableStats {
 impl PerColumnStats {
     pub fn new(
         mcvs: Box<dyn MostCommonValues>,
-        ndistinct: i32,
+        ndistinct: u64,
         null_frac: f64,
         distr: Box<dyn Distribution>,
     ) -> Self {
@@ -612,7 +820,7 @@ mod tests {
         }
     }
 
-    fn empty() -> Self {
+    pub fn empty() -> Self {
         MockMostCommonValues::new(vec![])
     }
 }
@@ -664,7 +872,7 @@ mod tests {
         OptCostModel::new(
             vec![(
                 String::from(TABLE1_NAME),
-                PerTableStats::new(100, vec![per_column_stats]),
+                PerTableStats::new(100, vec![Some(per_column_stats)]),
             )]
             .into_iter()
             .collect(),
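
`from_record_batches` builds a table's statistics in one streaming pass: per supported column it keeps a HyperLogLog for `ndistinct`, a TDigest for the value distribution, and a running null count, while unsupported column types get `None` — which is why the selectivity lookups above now match on `Some(Some(per_column_stats))`. A hypothetical usage sketch with a hand-built batch (the schema and values are made up; everything except `PerTableStats` is plain arrow API):

    use std::sync::Arc;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::arrow::array::{ArrayRef, Int32Array, RecordBatch, RecordBatchIterator};

    fn example() -> anyhow::Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, true)]));
        // One nullable Int32 column with a single null (counted into null_frac).
        let col: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None]));
        let batch = RecordBatch::try_new(schema.clone(), vec![col])?;
        let _stats = PerTableStats::from_record_batches(RecordBatchIterator::new(
            vec![Ok(batch)],
            schema,
        ))?;
        Ok(())
    }
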
diff --git a/optd-gungnir/src/stats/hyperloglog.rs b/optd-gungnir/src/stats/hyperloglog.rs
index 584597f8..aca39eb2 100644
--- a/optd-gungnir/src/stats/hyperloglog.rs
+++ b/optd-gungnir/src/stats/hyperloglog.rs
@@ -8,6 +8,8 @@
 use crate::stats::murmur2::murmur_hash;
 use std::cmp::max;
 
+pub const DEFAULT_PRECISION: u8 = 12;
+
 /// Trait to transform any object into a stream of bytes.
 pub trait ByteSerializable {
     fn to_bytes(&self) -> Vec<u8>;
@@ -15,6 +17,7 @@ pub trait ByteSerializable {
 
 /// The HyperLogLog (HLL) structure to provide a statistical estimate of NDistinct.
 /// For safety reasons, HLLs can only count elements of the same ByteSerializable type.
+#[derive(Clone)]
 pub struct HyperLogLog {
     registers: Vec<u8>, // The buckets to estimate HLL on (i.e. upper p bits).
     precision: u8,      // The precision (p) of our HLL; 4 <= p <= 16.
@@ -29,6 +32,13 @@ impl ByteSerializable for String {
     }
 }
 
+// Serialize common data types for hashing (bool).
+impl ByteSerializable for bool {
+    fn to_bytes(&self) -> Vec<u8> {
+        (*self as u8).to_bytes()
+    }
+}
+
 // Serialize common data types for hashing (numeric).
 macro_rules! impl_byte_serializable_for_numeric {
     ($($type:ty),*) => {
diff --git a/optd-gungnir/src/stats/tdigest.rs b/optd-gungnir/src/stats/tdigest.rs
index 7fe99536..7f24d08c 100644
--- a/optd-gungnir/src/stats/tdigest.rs
+++ b/optd-gungnir/src/stats/tdigest.rs
@@ -6,7 +6,10 @@
 use itertools::Itertools;
 use std::f64::consts::PI;
 
+pub const DEFAULT_COMPRESSION: f64 = 200.0;
+
 /// The TDigest structure for the statistical aggregator to query quantiles.
+#[derive(Clone)]
 pub struct TDigest {
     /// A sorted array of Centroids, according to their mean.
     centroids: Vec<Centroid>,
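
These additions exist for the single-pass collector above: `Clone` lets base_cost.rs initialize one sketch per column with `vec![HyperLogLog::new(..); col_cnt]`, the two constants give every column the same sketch configuration, and the `bool` impl covers `DataType::Boolean` columns. A small sketch of the calls the cost model now makes (usage mirrors this diff; any signature detail beyond it is an assumption):

    use optd_gungnir::stats::hyperloglog::{self, HyperLogLog};
    use optd_gungnir::stats::tdigest::{self, TDigest};

    fn sketch() {
        let mut hll = HyperLogLog::new(hyperloglog::DEFAULT_PRECISION);
        hll.aggregate(&[true, false, true]); // bool is now ByteSerializable
        let _ndistinct = hll.n_distinct();   // statistical estimate, ~2 here

        // merge_values takes the digest by value and returns the merged one,
        // matching the `distr.take().unwrap().merge_values(..)` pattern above.
        let distr = TDigest::new(tdigest::DEFAULT_COMPRESSION);
        let _distr = distr.merge_values(&mut vec![1.0_f64, 2.0, 3.0]);
    }
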
diff --git a/optd-perftest/Cargo.toml b/optd-perftest/Cargo.toml
index a7891d68..b9837918 100644
--- a/optd-perftest/Cargo.toml
+++ b/optd-perftest/Cargo.toml
@@ -16,6 +16,7 @@ datafusion = { version = "32.0.0", features = [
 ] }
 optd-datafusion-repr = { path = "../optd-datafusion-repr" }
 optd-datafusion-bridge = { path = "../optd-datafusion-bridge" }
+optd-gungnir = { path = "../optd-gungnir" }
 datafusion-optd-cli = { path = "../datafusion-optd-cli" }
 futures = "0.3"
 anyhow = { version = "1", features = ["backtrace"] }
diff --git a/optd-perftest/src/cardtest.rs b/optd-perftest/src/cardtest.rs
index 9ec19f61..5715dba7 100644
--- a/optd-perftest/src/cardtest.rs
+++ b/optd-perftest/src/cardtest.rs
@@ -2,7 +2,7 @@ use std::collections::HashMap;
 use std::path::Path;
 
 use crate::postgres_db::PostgresDb;
-use crate::{benchmark::Benchmark, datafusion_db_cardtest::DatafusionDb, tpch::TpchConfig};
+use crate::{benchmark::Benchmark, datafusion_db::DatafusionDb, tpch::TpchConfig};
 
 use anyhow::{self};
 use async_trait::async_trait;
diff --git a/optd-perftest/src/datafusion_db_cardtest.rs b/optd-perftest/src/datafusion_db.rs
similarity index 66%
rename from optd-perftest/src/datafusion_db_cardtest.rs
rename to optd-perftest/src/datafusion_db.rs
index e4b66a52..95d0da6f 100644
--- a/optd-perftest/src/datafusion_db_cardtest.rs
+++ b/optd-perftest/src/datafusion_db.rs
@@ -11,7 +11,11 @@ use crate::{
 };
 use async_trait::async_trait;
 use datafusion::{
-    arrow::util::display::{ArrayFormatter, FormatOptions},
+    arrow::{
+        array::RecordBatchIterator,
+        csv::ReaderBuilder,
+        util::display::{ArrayFormatter, FormatOptions},
+    },
     execution::{
         config::SessionConfig,
         context::{SessionContext, SessionState},
@@ -22,7 +26,7 @@ use datafusion::{
 use datafusion_optd_cli::helper::unescape_input;
 use lazy_static::lazy_static;
 use optd_datafusion_bridge::{DatafusionCatalog, OptdQueryPlanner};
-use optd_datafusion_repr::{cost::BaseTableStats, DatafusionOptimizer};
+use optd_datafusion_repr::{cost::BaseTableStats, cost::PerTableStats, DatafusionOptimizer};
 use regex::Regex;
 
 pub struct DatafusionDb {
@@ -40,7 +44,6 @@ impl CardtestRunnerDBHelper for DatafusionDb {
         &mut self,
         benchmark: &Benchmark,
     ) -> anyhow::Result<Vec<usize>> {
-        self.clear_state().await?;
         self.load_benchmark_data(benchmark).await?;
         match benchmark {
             Benchmark::Test => unimplemented!(),
@@ -52,7 +55,6 @@ impl CardtestRunnerDBHelper for DatafusionDb {
         &mut self,
         benchmark: &Benchmark,
     ) -> anyhow::Result<Vec<usize>> {
-        self.clear_state().await?;
         self.load_benchmark_data(benchmark).await?;
         match benchmark {
             Benchmark::Test => unimplemented!(),
@@ -65,17 +67,21 @@ impl DatafusionDb {
     pub async fn new<P: AsRef<Path>>(workspace_dpath: P) -> anyhow::Result<Self> {
         Ok(DatafusionDb {
             workspace_dpath: workspace_dpath.as_ref().to_path_buf(),
-            ctx: Self::new_session_ctx().await?,
+            ctx: Self::new_session_ctx(None).await?,
         })
     }
 
-    /// Reset data and metadata.
-    async fn clear_state(&mut self) -> anyhow::Result<()> {
-        self.ctx = Self::new_session_ctx().await?;
+    /// Reset the [`SessionContext`] to a clean state, but initialize the
+    /// optimizer with pre-generated statistics.
+    ///
+    /// A more ideal way to generate statistics would be to use the `ANALYZE`
+    /// command in SQL, but DataFusion does not support that yet.
+    async fn clear_state(&mut self, stats: Option<BaseTableStats>) -> anyhow::Result<()> {
+        self.ctx = Self::new_session_ctx(stats).await?;
         Ok(())
     }
 
-    async fn new_session_ctx() -> anyhow::Result<SessionContext> {
+    async fn new_session_ctx(stats: Option<BaseTableStats>) -> anyhow::Result<SessionContext> {
         let session_config = SessionConfig::from_env()?.with_information_schema(true);
         let rn_config = RuntimeConfig::new();
         let runtime_env = RuntimeEnv::new(rn_config.clone())?;
@@ -84,7 +90,7 @@ impl DatafusionDb {
             SessionState::new_with_config_rt(session_config.clone(), Arc::new(runtime_env));
         let optimizer: DatafusionOptimizer = DatafusionOptimizer::new_physical(
             Arc::new(DatafusionCatalog::new(state.catalog_list())),
-            BaseTableStats::default(),
+            stats.unwrap_or_default(),
             true,
         );
         state = state.with_physical_optimizer_rules(vec![]);
@@ -95,15 +101,15 @@ impl DatafusionDb {
         Ok(ctx)
     }
 
-    async fn execute(&self, sql: &str) -> anyhow::Result<Vec<Vec<String>>> {
+    async fn execute(ctx: &SessionContext, sql: &str) -> anyhow::Result<Vec<Vec<String>>> {
         let sql = unescape_input(sql)?;
         let dialect = Box::new(GenericDialect);
         let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
         let mut result = Vec::new();
         for statement in statements {
             let df = {
-                let plan = self.ctx.state().statement_to_plan(statement).await?;
-                self.ctx.execute_logical_plan(plan).await?
+                let plan = ctx.state().statement_to_plan(statement).await?;
+                ctx.execute_logical_plan(plan).await?
             };
 
             let batches = df.collect().await?;
@@ -159,7 +165,7 @@ impl DatafusionDb {
     }
 
     async fn eval_query_truecard(&self, sql: &str) -> anyhow::Result<usize> {
-        let rows = self.execute(sql).await?;
+        let rows = Self::execute(&self.ctx, sql).await?;
         let num_rows = rows.len();
         Ok(num_rows)
     }
@@ -168,7 +174,7 @@ impl DatafusionDb {
         lazy_static! {
             static ref ROW_CNT_RE: Regex = Regex::new(r"row_cnt=(\d+\.\d+)").unwrap();
         }
-        let explains = self.execute(&format!("explain verbose {}", sql)).await?;
+        let explains = Self::execute(&self.ctx, &format!("explain verbose {}", sql)).await?;
         // Find first occurrence of row_cnt=... in the output.
         let row_cnt = explains
             .iter()
@@ -195,9 +201,14 @@ impl DatafusionDb {
     }
 
     async fn load_tpch_data(&mut self, tpch_config: &TpchConfig) -> anyhow::Result<()> {
+        // Generate the tables.
         let tpch_kit = TpchKit::build(&self.workspace_dpath)?;
         tpch_kit.gen_tables(tpch_config)?;
 
+        // Generate the stats.
+        let stats = self.load_tpch_stats(&tpch_kit, tpch_config).await?;
+        self.clear_state(Some(stats)).await?;
+
+        // Create the tables.
         let ddls = fs::read_to_string(&tpch_kit.schema_fpath)?;
         let ddls = ddls
@@ -206,18 +217,21 @@ impl DatafusionDb {
             .filter(|s| !s.is_empty())
             .collect::<Vec<_>>();
         for ddl in ddls {
-            self.execute(ddl).await?;
+            Self::execute(&self.ctx, ddl).await?;
         }
 
         // Load the data by creating an external table first and copying the data to real tables.
         let tbl_fpath_iter = tpch_kit.get_tbl_fpath_iter(tpch_config).unwrap();
         for tbl_fpath in tbl_fpath_iter {
             let tbl_name = tbl_fpath.file_stem().unwrap().to_str().unwrap();
-            self.execute(&format!(
-                "create external table {}_tbl stored as csv delimiter '|' location '{}';",
-                tbl_name,
-                tbl_fpath.to_str().unwrap()
-            ))
+            Self::execute(
+                &self.ctx,
+                &format!(
+                    "create external table {}_tbl stored as csv delimiter '|' location '{}';",
+                    tbl_name,
+                    tbl_fpath.to_str().unwrap()
+                ),
+            )
             .await?;
 
             // Get the number of columns of this table.
@@ -235,14 +249,63 @@ impl DatafusionDb {
                 .map(|i| format!("column_{}", i))
                 .collect::<Vec<_>>()
                 .join(", ");
-            self.execute(&format!(
-                "insert into {} select {} from {}_tbl;",
-                tbl_name, projection_list, tbl_name,
-            ))
+            Self::execute(
+                &self.ctx,
+                &format!(
+                    "insert into {} select {} from {}_tbl;",
+                    tbl_name, projection_list, tbl_name,
+                ),
+            )
             .await?;
         }
+
         Ok(())
     }
+
+    async fn load_tpch_stats(
+        &self,
+        tpch_kit: &TpchKit,
+        tpch_config: &TpchConfig,
+    ) -> anyhow::Result<BaseTableStats> {
+        // To get the schema of each table.
+        let ctx = Self::new_session_ctx(None).await?;
+        let ddls = fs::read_to_string(&tpch_kit.schema_fpath)?;
+        let ddls = ddls
+            .split(';')
+            .map(|s| s.trim())
+            .filter(|s| !s.is_empty())
+            .collect::<Vec<_>>();
+        for ddl in ddls {
+            Self::execute(&ctx, ddl).await?;
+        }
+        let mut base_table_stats = BaseTableStats::default();
+        for tbl_fpath in tpch_kit.get_tbl_fpath_iter(tpch_config).unwrap() {
+            let tbl_name = tbl_fpath.file_stem().unwrap().to_str().unwrap();
+            let schema = ctx
+                .catalog("datafusion")
+                .unwrap()
+                .schema("public")
+                .unwrap()
+                .table(tbl_name)
+                .await
+                .unwrap()
+                .schema();
+            // Load the .tbl file into record batches using arrow.
+            let tbl_file = fs::File::open(&tbl_fpath)?;
+            let csv_reader = ReaderBuilder::new(schema.clone())
+                .has_header(false)
+                .with_delimiter(b'|')
+                .build(tbl_file)
+                .unwrap();
+            let batch_iter = RecordBatchIterator::new(csv_reader, schema);
+            base_table_stats.insert(
+                tbl_name.to_string(),
+                PerTableStats::from_record_batches(batch_iter)?,
+            );
+            log::debug!("statistics generated for table: {}", tbl_name);
+        }
+        Ok(base_table_stats)
+    }
 }
 
 unsafe impl Send for DatafusionDb {}
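
`load_tpch_stats` replays the TPC-H DDL into a throwaway `SessionContext` purely to recover each table's schema, then streams the generated `.tbl` files through arrow's CSV reader into `PerTableStats::from_record_batches`. Condensed to its kernel (the file path and one-column schema here are placeholders):

    use std::{fs, sync::Arc};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::arrow::{array::RecordBatchIterator, csv::ReaderBuilder};
    use optd_datafusion_repr::cost::PerTableStats;

    fn stats_for_tbl_file() -> anyhow::Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, false)]));
        let file = fs::File::open("some_table.tbl")?;
        let csv_reader = ReaderBuilder::new(schema.clone())
            .has_header(false)    // .tbl files carry no header row
            .with_delimiter(b'|') // TPC-H uses '|' as the field separator
            .build(file)?;
        let batch_iter = RecordBatchIterator::new(csv_reader, schema);
        let _stats = PerTableStats::from_record_batches(batch_iter)?;
        Ok(())
    }
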
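
With the module renamed to `datafusion_db`, statistics generation piggybacks on data loading rather than on each evaluation call, which is why the `clear_state` calls disappeared from the two `eval_benchmark_*` methods. An end-to-end sketch of how a caller drives the new flow — the `Benchmark::Tpch` variant and the `tpch_config`/`workspace_dpath` values are assumed from the surrounding crate, not shown in this diff:

    let mut db = DatafusionDb::new(&workspace_dpath).await?;
    // load_benchmark_data -> load_tpch_data now: generate the .tbl files, build
    // BaseTableStats via load_tpch_stats, rebuild the SessionContext with
    // clear_state(Some(stats)), and only then create and populate the tables.
    let estcards = db.eval_benchmark_estcards(&Benchmark::Tpch(tpch_config)).await?;
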