From 1c557a418cbdd991390d193ebd515bc5c06c7c3a Mon Sep 17 00:00:00 2001
From: Zhidong Guo <52783948+Gun9niR@users.noreply.github.com>
Date: Wed, 20 Mar 2024 11:24:32 -0400
Subject: [PATCH] feat: integrate stats into optd (#117)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Generate the statistics in perftest and put them into the `BaseCostModel` of `DatafusionOptimizer`. Below is a comparison of the same plan before and after stats are added; note the `PhysicalScan` nodes, whose costs have changed. The final cardinality remains the same because when stats on a column are missing, we use a very small magic number `INVALID_SELECTIVITY` (0.001), which effectively sets the cardinality to 1. A small sketch of how this fallback behaves follows the diffstat below.

### Todos in Future PRs

- Support generating stats on `Utf8`.
- Set a better magic number.
- Generate MCVs.

### Before

```
plan space size budget used, not applying logical rules any more.
current plan space: 1094
explain: PhysicalSort
├── exprs:SortOrder { order: Desc }
│   └── #1
├── cost: weighted=185.17,row_cnt=1.00,compute=179.17,io=6.00
└── PhysicalProjection { exprs: [ #0, #1 ], cost: weighted=182.12,row_cnt=1.00,compute=176.12,io=6.00 }
    └── PhysicalAgg
        ├── aggrs:Agg(Sum)
        │   └── Mul
        │       ├── #0
        │       └── Sub
        │           ├── 1
        │           └── #1
        ├── groups: [ #2 ]
        ├── cost: weighted=182.02,row_cnt=1.00,compute=176.02,io=6.00
        └── PhysicalProjection { exprs: [ #0, #1, #2 ], cost: weighted=64.90,row_cnt=1.00,compute=58.90,io=6.00 }
            └── PhysicalProjection { exprs: [ #0, #1, #4, #5, #6 ], cost: weighted=64.76,row_cnt=1.00,compute=58.76,io=6.00 }
                └── PhysicalProjection { exprs: [ #2, #3, #5, #6, #7, #8, #9 ], cost: weighted=64.54,row_cnt=1.00,compute=58.54,io=6.00 }
                    └── PhysicalProjection { exprs: [ #0, #3, #4, #5, #6, #7, #8, #9, #10, #11 ], cost: weighted=64.24,row_cnt=1.00,compute=58.24,io=6.00 }
                        └── PhysicalProjection { exprs: [ #1, #2, #4, #5, #6, #7, #8, #9, #10, #11, #12, #13 ], cost: weighted=63.82,row_cnt=1.00,compute=57.82,io=6.00 }
                            └── PhysicalProjection { exprs: [ #0, #3, #8, #9, #10, #11, #12, #13, #14, #15, #16, #17, #18, #19 ], cost: weighted=63.32,row_cnt=1.00,compute=57.32,io=6.00 }
                                └── PhysicalNestedLoopJoin
                                    ├── join_type: Inner
                                    ├── cond:And
                                    │   ├── Eq
                                    │   │   ├── #11
                                    │   │   └── #14
                                    │   └── Eq
                                    │       ├── #3
                                    │       └── #15
                                    ├── cost: weighted=62.74,row_cnt=1.00,compute=56.74,io=6.00
                                    ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #1 ], cost: weighted=35.70,row_cnt=1.00,compute=32.70,io=3.00 }
                                    │   ├── PhysicalScan { table: customer, cost: weighted=1.00,row_cnt=1.00,compute=0.00,io=1.00 }
                                    │   └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: weighted=31.64,row_cnt=1.00,compute=29.64,io=2.00 }
                                    │       ├── PhysicalProjection { exprs: [ #0, #1 ], cost: weighted=27.40,row_cnt=1.00,compute=26.40,io=1.00 }
                                    │       │   └── PhysicalFilter
                                    │       │       ├── cond:And
                                    │       │       │   ├── Geq
                                    │       │       │   │   ├── #2
                                    │       │       │   │   └── 9131
                                    │       │       │   └── Lt
                                    │       │       │       ├── #2
                                    │       │       │       └── 9496
                                    │       │       ├── cost: weighted=27.30,row_cnt=1.00,compute=26.30,io=1.00
                                    │       │       └── PhysicalProjection { exprs: [ #0, #1, #4 ], cost: weighted=1.14,row_cnt=1.00,compute=0.14,io=1.00 }
                                    │       │           └── PhysicalScan { table: orders, cost: weighted=1.00,row_cnt=1.00,compute=0.00,io=1.00 }
                                    │       └── PhysicalProjection { exprs: [ #0, #2, #5, #6 ], cost: weighted=1.18,row_cnt=1.00,compute=0.18,io=1.00 }
                                    │           └── PhysicalScan { table: lineitem, cost: weighted=1.00,row_cnt=1.00,compute=0.00,io=1.00 }
                                    └── PhysicalProjection { exprs: [ #0, #3, #7, #8, #9, #10 ], cost: weighted=15.72,row_cnt=1.00,compute=12.72,io=3.00 }
                                        └── PhysicalHashJoin { join_type: Inner, left_keys: [ #3 ], right_keys: [ #0 ], cost: weighted=15.46,row_cnt=1.00,compute=12.46,io=3.00 }
                                            ├── PhysicalScan { table: supplier, cost: weighted=1.00,row_cnt=1.00,compute=0.00,io=1.00 }
                                            └── PhysicalHashJoin { join_type: Inner, left_keys: [ #2 ], right_keys: [ #0 ], cost: weighted=11.40,row_cnt=1.00,compute=9.40,io=2.00 }
                                                ├── PhysicalProjection { exprs: [ #0, #1, #2 ], cost: weighted=1.14,row_cnt=1.00,compute=0.14,io=1.00 }
                                                │   └── PhysicalScan { table: nation, cost: weighted=1.00,row_cnt=1.00,compute=0.00,io=1.00 }
                                                └── PhysicalProjection { exprs: [ #0 ], cost: weighted=7.20,row_cnt=1.00,compute=6.20,io=1.00 }
                                                    └── PhysicalFilter
                                                        ├── cond:Eq
                                                        │   ├── #1
                                                        │   └── "AMERICA"
                                                        ├── cost: weighted=7.14,row_cnt=1.00,compute=6.14,io=1.00
                                                        └── PhysicalProjection { exprs: [ #0, #1 ], cost: weighted=1.10,row_cnt=1.00,compute=0.10,io=1.00 }
                                                            └── PhysicalScan { table: region, cost: weighted=1.00,row_cnt=1.00,compute=0.00,io=1.00 }
plan space size budget used, not applying logical rules any more.
current plan space: 1094
qerrors: {"DataFusion": [5.0]}
```

### After

```
plan space size budget used, not applying logical rules any more.
current plan space: 1094
explain: PhysicalSort
├── exprs:SortOrder { order: Desc }
│   └── #1
├── cost: weighted=336032.32,row_cnt=1.00,compute=259227.32,io=76805.00
└── PhysicalProjection { exprs: [ #0, #1 ], cost: weighted=336029.27,row_cnt=1.00,compute=259224.27,io=76805.00 }
    └── PhysicalAgg
        ├── aggrs:Agg(Sum)
        │   └── Mul
        │       ├── #0
        │       └── Sub
        │           ├── 1
        │           └── #1
        ├── groups: [ #2 ]
        ├── cost: weighted=336029.17,row_cnt=1.00,compute=259224.17,io=76805.00
        └── PhysicalProjection { exprs: [ #0, #1, #2 ], cost: weighted=335912.05,row_cnt=1.00,compute=259107.05,io=76805.00 }
            └── PhysicalProjection { exprs: [ #0, #1, #4, #5, #6 ], cost: weighted=335911.91,row_cnt=1.00,compute=259106.91,io=76805.00 }
                └── PhysicalProjection { exprs: [ #2, #3, #5, #6, #7, #8, #9 ], cost: weighted=335911.69,row_cnt=1.00,compute=259106.69,io=76805.00 }
                    └── PhysicalProjection { exprs: [ #0, #3, #4, #5, #6, #7, #8, #9, #10, #11 ], cost: weighted=335911.39,row_cnt=1.00,compute=259106.39,io=76805.00 }
                        └── PhysicalProjection { exprs: [ #1, #2, #4, #5, #6, #7, #8, #9, #10, #11, #12, #13 ], cost: weighted=335910.97,row_cnt=1.00,compute=259105.97,io=76805.00 }
                            └── PhysicalProjection { exprs: [ #0, #3, #8, #9, #10, #11, #12, #13, #14, #15, #16, #17, #18, #19 ], cost: weighted=335910.47,row_cnt=1.00,compute=259105.47,io=76805.00 }
                                └── PhysicalNestedLoopJoin
                                    ├── join_type: Inner
                                    ├── cond:And
                                    │   ├── Eq
                                    │   │   ├── #11
                                    │   │   └── #14
                                    │   └── Eq
                                    │       ├── #3
                                    │       └── #15
                                    ├── cost: weighted=335909.89,row_cnt=1.00,compute=259104.89,io=76805.00
                                    ├── PhysicalProjection { exprs: [ #6, #7, #8, #9, #10, #11, #12, #13, #0, #1, #2, #3, #4, #5 ], cost: weighted=335619.21,row_cnt=1.00,compute=258944.21,io=76675.00 }
                                    │   └── PhysicalHashJoin { join_type: Inner, left_keys: [ #1 ], right_keys: [ #0 ], cost: weighted=335618.63,row_cnt=1.00,compute=258943.63,io=76675.00 }
                                    │       ├── PhysicalProjection { exprs: [ #4, #5, #0, #1, #2, #3 ], cost: weighted=332616.57,row_cnt=1.00,compute=257441.57,io=75175.00 }
                                    │       │   └── PhysicalProjection { exprs: [ #0, #2, #5, #6, #16, #17 ], cost: weighted=332616.31,row_cnt=1.00,compute=257441.31,io=75175.00 }
                                    │       │       └── PhysicalProjection { exprs: [ #2, #3, #4, #5, #6, #7, #8, #9, #10, #11, #12, #13, #14, #15, #16, #17, #0, #1 ], cost: weighted=332616.05,row_cnt=1.00,compute=257441.05,io=75175.00 }
                                    │       │           └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: weighted=332615.31,row_cnt=1.00,compute=257440.31,io=75175.00 }
                                    │       │               ├── PhysicalProjection { exprs: [ #0, #1 ], cost: weighted=212263.25,row_cnt=1.00,compute=197263.25,io=15000.00 }
                                    │       │               │   └── PhysicalFilter
                                    │       │               │       ├── cond:And
                                    │       │               │       │   ├── Geq
                                    │       │               │       │   │   ├── #2
                                    │       │               │       │   │   └── 9131
                                    │       │               │       │   └── Lt
                                    │       │               │       │       ├── #2
                                    │       │               │       │       └── 9496
                                    │       │               │       ├── cost: weighted=212263.15,row_cnt=1.00,compute=197263.15,io=15000.00
                                    │       │               │       └── PhysicalProjection { exprs: [ #0, #1, #4 ], cost: weighted=16050.07,row_cnt=15000.00,compute=1050.07,io=15000.00 }
                                    │       │               │           └── PhysicalScan { table: orders, cost: weighted=15000.00,row_cnt=15000.00,compute=0.00,io=15000.00 }
                                    │       │               └── PhysicalScan { table: lineitem, cost: weighted=60175.00,row_cnt=60175.00,compute=0.00,io=60175.00 }
                                    │       └── PhysicalScan { table: customer, cost: weighted=1500.00,row_cnt=1500.00,compute=0.00,io=1500.00 }
                                    └── PhysicalProjection { exprs: [ #0, #3, #7, #8, #9, #10 ], cost: weighted=279.36,row_cnt=1.00,compute=149.36,io=130.00 }
                                        └── PhysicalProjection { exprs: [ #4, #5, #6, #7, #8, #9, #10, #0, #1, #2, #3 ], cost: weighted=279.10,row_cnt=1.00,compute=149.10,io=130.00 }
                                            └── PhysicalProjection { exprs: [ #1, #2, #3, #0, #4, #5, #6, #7, #8, #9, #10 ], cost: weighted=278.64,row_cnt=1.00,compute=148.64,io=130.00 }
                                                └── PhysicalHashJoin { join_type: Inner, left_keys: [ #1 ], right_keys: [ #3 ], cost: weighted=278.18,row_cnt=1.00,compute=148.18,io=130.00 }
                                                    ├── PhysicalProjection { exprs: [ #3, #0, #1, #2 ], cost: weighted=76.12,row_cnt=1.00,compute=46.12,io=30.00 }
                                                    │   └── PhysicalProjection { exprs: [ #0, #1, #2, #4 ], cost: weighted=75.94,row_cnt=1.00,compute=45.94,io=30.00 }
                                                    │       └── PhysicalProjection { exprs: [ #1, #2, #3, #4, #0 ], cost: weighted=75.76,row_cnt=1.00,compute=45.76,io=30.00 }
                                                    │           └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #2 ], cost: weighted=75.54,row_cnt=1.00,compute=45.54,io=30.00 }
                                                    │               ├── PhysicalProjection { exprs: [ #0 ], cost: weighted=23.48,row_cnt=1.00,compute=18.48,io=5.00 }
                                                    │               │   └── PhysicalFilter
                                                    │               │       ├── cond:Eq
                                                    │               │       │   ├── #1
                                                    │               │       │   └── "AMERICA"
                                                    │               │       ├── cost: weighted=23.42,row_cnt=1.00,compute=18.42,io=5.00
                                                    │               │       └── PhysicalProjection { exprs: [ #0, #1 ], cost: weighted=5.30,row_cnt=5.00,compute=0.30,io=5.00 }
                                                    │               │           └── PhysicalScan { table: region, cost: weighted=5.00,row_cnt=5.00,compute=0.00,io=5.00 }
                                                    │               └── PhysicalScan { table: nation, cost: weighted=25.00,row_cnt=25.00,compute=0.00,io=25.00 }
                                                    └── PhysicalScan { table: supplier, cost: weighted=100.00,row_cnt=100.00,compute=0.00,io=100.00 }
plan space size budget used, not applying logical rules any more.
current plan space: 1094
qerrors: {"DataFusion": [5.0]}
```
---
 Cargo.lock                                    |   1 +
 .../src/cost/adaptive_cost.rs                 |   8 +-
 optd-datafusion-repr/src/cost/base_cost.rs    | 224 +++++++++++++++++-
 optd-gungnir/src/stats/hyperloglog.rs         |  10 +
 optd-gungnir/src/stats/tdigest.rs             |   3 +
 optd-perftest/Cargo.toml                      |   1 +
 optd-perftest/src/cardtest.rs                 |   2 +-
 ...fusion_db_cardtest.rs => datafusion_db.rs} | 113 +++++++--
 optd-perftest/src/lib.rs                      |   2 +-
 9 files changed, 324 insertions(+), 40 deletions(-)
 rename optd-perftest/src/{datafusion_db_cardtest.rs => datafusion_db.rs} (66%)
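To make the fallback concrete, here is a small, self-contained sketch (illustrative code, not taken from the patch; `ColumnStats` and `eq_selectivity` are hypothetical names) of how a per-column stats lookup degrades to the magic number when a column has no statistics:

```rust
/// Hypothetical stand-in for the per-column stats the cost model keeps.
struct ColumnStats {
    ndistinct: u64, // estimated via HyperLogLog in this patch
    null_frac: f64, // fraction of NULLs; NULLs never satisfy an equality predicate
}

/// The magic fallback optd uses when a column has no stats.
const INVALID_SELECTIVITY: f64 = 0.001;

/// Selectivity of `col = <literal>`, assuming a uniform distribution over the
/// non-null distinct values (MCVs are ignored here for simplicity).
fn eq_selectivity(stats: Option<&ColumnStats>) -> f64 {
    match stats {
        Some(s) if s.ndistinct > 0 => (1.0 - s.null_frac) / s.ndistinct as f64,
        _ => INVALID_SELECTIVITY, // missing stats
    }
}

fn main() {
    let stats = ColumnStats { ndistinct: 25, null_frac: 0.0 };
    println!("{}", eq_selectivity(Some(&stats))); // 0.04
    println!("{}", eq_selectivity(None));         // 0.001
}
```

Because every predicate without stats contributes this tiny factor, estimated cardinalities collapse toward 1, which is why both plans above still report `row_cnt=1.00` at the top even though the scan costs now differ.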
diff --git a/Cargo.lock b/Cargo.lock
index 5dbd3d26..39357a49 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2737,6 +2737,7 @@ dependencies = [
  "log",
  "optd-datafusion-bridge",
  "optd-datafusion-repr",
+ "optd-gungnir",
  "regex",
  "shlex",
  "tokio",
diff --git a/optd-datafusion-repr/src/cost/adaptive_cost.rs b/optd-datafusion-repr/src/cost/adaptive_cost.rs
index a0b07309..35625782 100644
--- a/optd-datafusion-repr/src/cost/adaptive_cost.rs
+++ b/optd-datafusion-repr/src/cost/adaptive_cost.rs
@@ -51,15 +51,13 @@ impl CostModel<OptRelNodeTyp> for AdaptiveCostModel {
     ) -> Cost {
         if let OptRelNodeTyp::PhysicalScan = node {
             let guard = self.runtime_row_cnt.lock().unwrap();
-            if let Some((runtime_row_cnt, iter)) = guard.history.get(&context.unwrap().group_id) {
+            if let Some((runtime_row_cnt, iter)) =
+                guard.history.get(&context.as_ref().unwrap().group_id)
+            {
                 if *iter + self.decay >= guard.iter_cnt {
                     let runtime_row_cnt = (*runtime_row_cnt).max(1) as f64;
                     return OptCostModel::cost(runtime_row_cnt, 0.0, runtime_row_cnt);
-                } else {
-                    return OptCostModel::cost(1.0, 0.0, 1.0);
                 }
-            } else {
-                return OptCostModel::cost(1.0, 0.0, 1.0);
             }
         }
         let (mut row_cnt, compute_cost, io_cost) = OptCostModel::cost_tuple(
diff --git a/optd-datafusion-repr/src/cost/base_cost.rs b/optd-datafusion-repr/src/cost/base_cost.rs
index f08c7b6f..3f5a6de8 100644
--- a/optd-datafusion-repr/src/cost/base_cost.rs
+++ b/optd-datafusion-repr/src/cost/base_cost.rs
@@ -6,12 +6,20 @@ use crate::{
     plan_nodes::{OptRelNodeRef, OptRelNodeTyp},
     properties::column_ref::ColumnRef,
 };
+use arrow_schema::{ArrowError, DataType};
+use datafusion::arrow::array::{
+    Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, Int16Array,
+    Int32Array, Int8Array, RecordBatch, RecordBatchIterator, RecordBatchReader, UInt16Array,
+    UInt32Array, UInt8Array,
+};
 use itertools::Itertools;
 use optd_core::{
     cascades::{CascadesOptimizer, RelNodeContext},
     cost::{Cost, CostModel},
     rel_node::{RelNode, RelNodeTyp, Value},
 };
+use optd_gungnir::stats::hyperloglog::{self, HyperLogLog};
+use optd_gungnir::stats::tdigest::{self, TDigest};
 
 fn compute_plan_node_cost<T: RelNodeTyp, C: CostModel<T>>(
     model: &C,
@@ -34,9 +42,207 @@ pub struct OptCostModel {
     per_table_stats_map: BaseTableStats,
 }
 
+struct MockMostCommonValues {
+    mcvs: HashMap<Value, f64>,
+}
+
+impl MockMostCommonValues {
+    pub fn empty() -> Self {
+        MockMostCommonValues {
+            mcvs: HashMap::new(),
+        }
+    }
+}
+
+impl MostCommonValues for MockMostCommonValues {
+    fn freq(&self, value: &Value) -> Option<f64> {
+        self.mcvs.get(value).copied()
+    }
+
+    fn total_freq(&self) -> f64 {
+        self.mcvs.values().sum()
+    }
+
+    fn freq_over_pred(&self, pred: Box<dyn Fn(&Value) -> bool>) -> f64 {
+        self.mcvs
+            .iter()
+            .filter(|(val, _)| pred(val))
+            .map(|(_, freq)| freq)
+            .sum()
+    }
+
+    fn cnt(&self) -> usize {
+        self.mcvs.len()
+    }
+}
+
 pub struct PerTableStats {
     row_cnt: usize,
-    per_column_stats_vec: Vec<PerColumnStats>,
+    per_column_stats_vec: Vec<Option<PerColumnStats>>,
+}
+
+impl PerTableStats {
+    pub fn from_record_batches<I: IntoIterator<Item = Result<RecordBatch, ArrowError>>>(
+        batch_iter: RecordBatchIterator<I>,
+    ) -> anyhow::Result<Self> {
+        let schema = batch_iter.schema();
+        let col_types = schema
+            .fields()
+            .iter()
+            .map(|f| f.data_type().clone())
+            .collect_vec();
+        let col_cnt = col_types.len();
+
+        let mut row_cnt = 0;
+        let mut mcvs = col_types
+            .iter()
+            .map(|col_type| {
+                if Self::is_type_supported(col_type) {
+                    Some(MockMostCommonValues::empty())
+                } else {
+                    None
+                }
+            })
+            .collect_vec();
+        let mut distr = col_types
+            .iter()
+            .map(|col_type| {
+                if Self::is_type_supported(col_type) {
+                    Some(TDigest::new(tdigest::DEFAULT_COMPRESSION))
+                } else {
+                    None
+                }
+            })
+            .collect_vec();
+        let mut hlls = vec![HyperLogLog::new(hyperloglog::DEFAULT_PRECISION); col_cnt];
+        let mut null_cnt = vec![0; col_cnt];
+
+        for batch in batch_iter {
+            let batch = batch?;
+            row_cnt += batch.num_rows();
+
+            // Enumerate the columns.
+            for (i, col) in batch.columns().iter().enumerate() {
+                let col_type = &col_types[i];
+                if Self::is_type_supported(col_type) {
+                    // Update null cnt.
+                    null_cnt[i] += col.null_count();
+
+                    Self::generate_stats_for_column(col, col_type, &mut distr[i], &mut hlls[i]);
+                }
+            }
+        }
+
+        // Assemble the per-column stats.
+        let mut per_column_stats_vec = Vec::with_capacity(col_cnt);
+        for i in 0..col_cnt {
+            per_column_stats_vec.push(if Self::is_type_supported(&col_types[i]) {
+                Some(PerColumnStats {
+                    mcvs: Box::new(mcvs[i].take().unwrap()) as Box<dyn MostCommonValues>,
+                    ndistinct: hlls[i].n_distinct(),
+                    null_frac: null_cnt[i] as f64 / row_cnt as f64,
+                    distr: Box::new(distr[i].take().unwrap()) as Box<dyn Distribution>,
+                })
+            } else {
+                None
+            });
+        }
+        Ok(Self {
+            row_cnt,
+            per_column_stats_vec,
+        })
+    }
+
+    fn is_type_supported(data_type: &DataType) -> bool {
+        matches!(
+            data_type,
+            DataType::Boolean
+                | DataType::Int8
+                | DataType::Int16
+                | DataType::Int32
+                | DataType::UInt8
+                | DataType::UInt16
+                | DataType::UInt32
+                | DataType::Float32
+                | DataType::Float64
+        )
+    }
+
+    /// Generate statistics for a column.
+    fn generate_stats_for_column(
+        col: &Arc<dyn Array>,
+        col_type: &DataType,
+        distr: &mut Option<TDigest>,
+        hll: &mut HyperLogLog,
+    ) {
+        macro_rules! generate_stats_for_col {
+            ({ $col:expr, $distr:expr, $hll:expr, $array_type:path, $to_f64:ident }) => {{
+                let array = $col.as_any().downcast_ref::<$array_type>().unwrap();
+                // Filter out `None` values.
+                let values = array.iter().filter_map(|x| x).collect::<Vec<_>>();
+
+                // Update distribution.
+                *$distr = {
+                    let mut f64_values = values.iter().map(|x| $to_f64(*x)).collect::<Vec<_>>();
+                    Some($distr.take().unwrap().merge_values(&mut f64_values))
+                };
+
+                // Update hll.
+                $hll.aggregate(&values);
+            }};
+        }
+
+        /// Convert a value to f64 with no out of range or precision loss.
+        fn to_f64_safe<T: Into<f64>>(val: T) -> f64 {
+            val.into()
+        }
+
+        /// Convert i128 to f64 with possible precision loss.
+        ///
+        /// Note: optd represents decimal with the significand as f64 (see `ConstantExpr::decimal`).
+        /// For instance, 0.04 of type `Decimal128(15, 2)` is just 4.0; the type information
+        /// is discarded. Therefore we must use the significand to generate the statistics.
+        fn i128_to_f64(val: i128) -> f64 {
+            val as f64
+        }
+
+        match col_type {
+            DataType::Boolean => {
+                generate_stats_for_col!({ col, distr, hll, BooleanArray, to_f64_safe })
+            }
+            DataType::Int8 => {
+                generate_stats_for_col!({ col, distr, hll, Int8Array, to_f64_safe })
+            }
+            DataType::Int16 => {
+                generate_stats_for_col!({ col, distr, hll, Int16Array, to_f64_safe })
+            }
+            DataType::Int32 => {
+                generate_stats_for_col!({ col, distr, hll, Int32Array, to_f64_safe })
+            }
+            DataType::UInt8 => {
+                generate_stats_for_col!({ col, distr, hll, UInt8Array, to_f64_safe })
+            }
+            DataType::UInt16 => {
+                generate_stats_for_col!({ col, distr, hll, UInt16Array, to_f64_safe })
+            }
+            DataType::UInt32 => {
+                generate_stats_for_col!({ col, distr, hll, UInt32Array, to_f64_safe })
+            }
+            DataType::Float32 => {
+                generate_stats_for_col!({ col, distr, hll, Float32Array, to_f64_safe })
+            }
+            DataType::Float64 => {
+                generate_stats_for_col!({ col, distr, hll, Float64Array, to_f64_safe })
+            }
+            DataType::Date32 => {
+                generate_stats_for_col!({ col, distr, hll, Date32Array, to_f64_safe })
+            }
+            DataType::Decimal128(_, _) => {
+                generate_stats_for_col!({ col, distr, hll, Decimal128Array, i128_to_f64 })
+            }
+            _ => unreachable!(),
+        }
+    }
 }
 
 pub struct PerColumnStats {
@@ -45,7 +251,7 @@
 
     // ndistinct _does_ include the values in mcvs
     // ndistinct _does not_ include nulls
-    ndistinct: i32,
+    ndistinct: u64,
 
     // postgres uses null_frac instead of something like "num_nulls" so we'll follow suit
     // my guess for why they use null_frac is because we only ever use the fraction of nulls, not the #
@@ -445,7 +651,8 @@ impl OptCostModel {
         is_eq: bool,
     ) -> Option<f64> {
         if let Some(per_table_stats) = self.per_table_stats_map.get(table) {
-            if let Some(per_column_stats) = per_table_stats.per_column_stats_vec.get(col_idx) {
+            if let Some(Some(per_column_stats)) = per_table_stats.per_column_stats_vec.get(col_idx)
+            {
                 let eq_freq = if let Some(freq) = per_column_stats.mcvs.freq(value) {
                     freq
                 } else {
@@ -484,7 +691,8 @@ impl OptCostModel {
         is_col_eq_val: bool,
     ) -> Option<f64> {
         if let Some(per_table_stats) = self.per_table_stats_map.get(table) {
-            if let Some(per_column_stats) = per_table_stats.per_column_stats_vec.get(col_idx) {
+            if let Some(Some(per_column_stats)) = per_table_stats.per_column_stats_vec.get(col_idx)
+            {
                 // because distr does not include the values in MCVs, we need to compute the CDFs there as well
                 // because nulls return false in any comparison, they are never included when computing range selectivity
                 let distr_leq_freq = per_column_stats.distr.cdf(value);
@@ -555,7 +763,7 @@ impl OptCostModel {
     }
 
 impl PerTableStats {
-    pub fn new(row_cnt: usize, per_column_stats_vec: Vec<PerColumnStats>) -> Self {
+    pub fn new(row_cnt: usize, per_column_stats_vec: Vec<Option<PerColumnStats>>) -> Self {
         Self {
             row_cnt,
             per_column_stats_vec,
@@ -566,7 +774,7 @@ impl PerTableStats {
 impl PerColumnStats {
     pub fn new(
         mcvs: Box<dyn MostCommonValues>,
-        ndistinct: i32,
+        ndistinct: u64,
         null_frac: f64,
         distr: Box<dyn Distribution>,
     ) -> Self {
@@ -612,7 +820,7 @@ mod tests {
         }
     }
 
-    fn empty() -> Self {
+    pub fn empty() -> Self {
         MockMostCommonValues::new(vec![])
     }
 }
@@ -664,7 +872,7 @@
         OptCostModel::new(
             vec![(
                 String::from(TABLE1_NAME),
-                PerTableStats::new(100, vec![per_column_stats]),
+                PerTableStats::new(100, vec![Some(per_column_stats)]),
             )]
             .into_iter()
             .collect(),
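`PerTableStats::from_record_batches` above is the new entry point for building stats outside of SQL. A hedged usage sketch — the imports mirror the ones this patch uses, but the one-column table and its values are made up:

```rust
use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema};
use datafusion::arrow::array::{ArrayRef, Int32Array, RecordBatch, RecordBatchIterator};
use optd_datafusion_repr::cost::PerTableStats;

fn main() -> anyhow::Result<()> {
    // A single-column table with one NULL and one duplicate value.
    let schema = Arc::new(Schema::new(vec![Field::new(
        "o_custkey",
        DataType::Int32,
        true,
    )]));
    let col: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(2)]));
    let batch = RecordBatch::try_new(schema.clone(), vec![col])?;

    // One in-memory batch here; the perftest loader below streams batches
    // from a '|'-delimited CSV reader instead.
    let batch_iter = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let stats = PerTableStats::from_record_batches(batch_iter)?;
    let _ = stats; // row_cnt = 4; the column gets ndistinct ~ 2 and null_frac = 0.25

    Ok(())
}
```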
diff --git a/optd-gungnir/src/stats/hyperloglog.rs b/optd-gungnir/src/stats/hyperloglog.rs
index 584597f8..aca39eb2 100644
--- a/optd-gungnir/src/stats/hyperloglog.rs
+++ b/optd-gungnir/src/stats/hyperloglog.rs
@@ -8,6 +8,8 @@ use crate::stats::murmur2::murmur_hash;
 use std::cmp::max;
 
+pub const DEFAULT_PRECISION: u8 = 12;
+
 /// Trait to transform any object into a stream of bytes.
 pub trait ByteSerializable {
     fn to_bytes(&self) -> Vec<u8>;
@@ -15,6 +17,7 @@ pub trait ByteSerializable {
 /// The HyperLogLog (HLL) structure to provide a statistical estimate of NDistinct.
 /// For safety reasons, HLLs can only count elements of the same ByteSerializable type.
+#[derive(Clone)]
 pub struct HyperLogLog {
     registers: Vec<u8>, // The buckets to estimate HLL on (i.e. upper p bits).
     precision: u8,      // The precision (p) of our HLL; 4 <= p <= 16.
@@ -29,6 +32,13 @@ impl ByteSerializable for String {
     }
 }
 
+// Serialize common data types for hashing (bool).
+impl ByteSerializable for bool {
+    fn to_bytes(&self) -> Vec<u8> {
+        (*self as u8).to_bytes()
+    }
+}
+
 // Serialize common data types for hashing (numeric).
 macro_rules! impl_byte_serializable_for_numeric {
     ($($type:ty),*) => {
diff --git a/optd-gungnir/src/stats/tdigest.rs b/optd-gungnir/src/stats/tdigest.rs
index 7fe99536..7f24d08c 100644
--- a/optd-gungnir/src/stats/tdigest.rs
+++ b/optd-gungnir/src/stats/tdigest.rs
@@ -6,7 +6,10 @@
 use itertools::Itertools;
 use std::f64::consts::PI;
 
+pub const DEFAULT_COMPRESSION: f64 = 200.0;
+
 /// The TDigest structure for the statistical aggregator to query quantiles.
+#[derive(Clone)]
 pub struct TDigest {
     /// A sorted array of Centroids, according to their mean.
     centroids: Vec<Centroid>,
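These two constants are what `from_record_batches` plugs into the optd-gungnir sketches. A short sketch of the primitives themselves, restricted to the calls visible in this patch (`new`, `aggregate`, `n_distinct`, `merge_values`); the exact signatures are assumptions inferred from the diff:

```rust
use optd_gungnir::stats::hyperloglog::{self, HyperLogLog};
use optd_gungnir::stats::tdigest::{self, TDigest};

fn main() {
    // HyperLogLog estimates ndistinct; the patch keeps one HLL per column and
    // feeds it every non-null value of every record batch.
    let mut hll = HyperLogLog::new(hyperloglog::DEFAULT_PRECISION);
    hll.aggregate(&vec![1.0f64, 2.0, 2.0, 3.0]);
    println!("ndistinct ~ {}", hll.n_distinct()); // ~ 3

    // TDigest approximates the value distribution; the cost model later reads
    // it through the `Distribution` trait to estimate range selectivities.
    let digest = TDigest::new(tdigest::DEFAULT_COMPRESSION);
    let mut values = vec![1.0f64, 2.0, 2.0, 3.0];
    let digest = digest.merge_values(&mut values);
    let _ = digest; // boxed into `PerColumnStats { distr, .. }` by the cost model
}
```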
diff --git a/optd-perftest/Cargo.toml b/optd-perftest/Cargo.toml
index a7891d68..b9837918 100644
--- a/optd-perftest/Cargo.toml
+++ b/optd-perftest/Cargo.toml
@@ -16,6 +16,7 @@ datafusion = { version = "32.0.0", features = [
 ] }
 optd-datafusion-repr = { path = "../optd-datafusion-repr" }
 optd-datafusion-bridge = { path = "../optd-datafusion-bridge" }
+optd-gungnir = { path = "../optd-gungnir" }
 datafusion-optd-cli = { path = "../datafusion-optd-cli" }
 futures = "0.3"
 anyhow = { version = "1", features = ["backtrace"] }
diff --git a/optd-perftest/src/cardtest.rs b/optd-perftest/src/cardtest.rs
index 9ec19f61..5715dba7 100644
--- a/optd-perftest/src/cardtest.rs
+++ b/optd-perftest/src/cardtest.rs
@@ -2,7 +2,7 @@ use std::collections::HashMap;
 use std::path::Path;
 
 use crate::postgres_db::PostgresDb;
-use crate::{benchmark::Benchmark, datafusion_db_cardtest::DatafusionDb, tpch::TpchConfig};
+use crate::{benchmark::Benchmark, datafusion_db::DatafusionDb, tpch::TpchConfig};
 
 use anyhow::{self};
 use async_trait::async_trait;
diff --git a/optd-perftest/src/datafusion_db_cardtest.rs b/optd-perftest/src/datafusion_db.rs
similarity index 66%
rename from optd-perftest/src/datafusion_db_cardtest.rs
rename to optd-perftest/src/datafusion_db.rs
index e4b66a52..95d0da6f 100644
--- a/optd-perftest/src/datafusion_db_cardtest.rs
+++ b/optd-perftest/src/datafusion_db.rs
@@ -11,7 +11,11 @@ use crate::{
 };
 use async_trait::async_trait;
 use datafusion::{
-    arrow::util::display::{ArrayFormatter, FormatOptions},
+    arrow::{
+        array::RecordBatchIterator,
+        csv::ReaderBuilder,
+        util::display::{ArrayFormatter, FormatOptions},
+    },
     execution::{
         config::SessionConfig,
         context::{SessionContext, SessionState},
@@ -22,7 +26,7 @@ use datafusion::{
 use datafusion_optd_cli::helper::unescape_input;
 use lazy_static::lazy_static;
 use optd_datafusion_bridge::{DatafusionCatalog, OptdQueryPlanner};
-use optd_datafusion_repr::{cost::BaseTableStats, DatafusionOptimizer};
+use optd_datafusion_repr::{cost::BaseTableStats, cost::PerTableStats, DatafusionOptimizer};
 use regex::Regex;
 
 pub struct DatafusionDb {
@@ -40,7 +44,6 @@ impl CardtestRunnerDBHelper for DatafusionDb {
         &mut self,
         benchmark: &Benchmark,
     ) -> anyhow::Result<Vec<usize>> {
-        self.clear_state().await?;
         self.load_benchmark_data(benchmark).await?;
         match benchmark {
             Benchmark::Test => unimplemented!(),
@@ -52,7 +55,6 @@ impl CardtestRunnerDBHelper for DatafusionDb {
         &mut self,
         benchmark: &Benchmark,
     ) -> anyhow::Result<Vec<usize>> {
-        self.clear_state().await?;
         self.load_benchmark_data(benchmark).await?;
         match benchmark {
             Benchmark::Test => unimplemented!(),
@@ -65,17 +67,21 @@ impl DatafusionDb {
     pub async fn new<P: AsRef<Path>>(workspace_dpath: P) -> anyhow::Result<Self> {
         Ok(DatafusionDb {
             workspace_dpath: workspace_dpath.as_ref().to_path_buf(),
-            ctx: Self::new_session_ctx().await?,
+            ctx: Self::new_session_ctx(None).await?,
         })
     }
 
-    /// Reset data and metadata.
-    async fn clear_state(&mut self) -> anyhow::Result<()> {
-        self.ctx = Self::new_session_ctx().await?;
+    /// Reset [`SessionContext`] to a clean state, but initialize the optimizer
+    /// with pre-generated statistics.
+    ///
+    /// A more ideal way to generate statistics would be to use the `ANALYZE`
+    /// command in SQL, but DataFusion does not support that yet.
+    async fn clear_state(&mut self, stats: Option<BaseTableStats>) -> anyhow::Result<()> {
+        self.ctx = Self::new_session_ctx(stats).await?;
         Ok(())
     }
 
-    async fn new_session_ctx() -> anyhow::Result<SessionContext> {
+    async fn new_session_ctx(stats: Option<BaseTableStats>) -> anyhow::Result<SessionContext> {
         let session_config = SessionConfig::from_env()?.with_information_schema(true);
         let rn_config = RuntimeConfig::new();
         let runtime_env = RuntimeEnv::new(rn_config.clone())?;
@@ -84,7 +90,7 @@ impl DatafusionDb {
             SessionState::new_with_config_rt(session_config.clone(), Arc::new(runtime_env));
         let optimizer: DatafusionOptimizer = DatafusionOptimizer::new_physical(
             Arc::new(DatafusionCatalog::new(state.catalog_list())),
-            BaseTableStats::default(),
+            stats.unwrap_or_default(),
             true,
         );
         state = state.with_physical_optimizer_rules(vec![]);
@@ -95,15 +101,15 @@ impl DatafusionDb {
         Ok(ctx)
     }
 
-    async fn execute(&self, sql: &str) -> anyhow::Result<Vec<Vec<String>>> {
+    async fn execute(ctx: &SessionContext, sql: &str) -> anyhow::Result<Vec<Vec<String>>> {
         let sql = unescape_input(sql)?;
         let dialect = Box::new(GenericDialect);
         let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
         let mut result = Vec::new();
         for statement in statements {
             let df = {
-                let plan = self.ctx.state().statement_to_plan(statement).await?;
-                self.ctx.execute_logical_plan(plan).await?
+                let plan = ctx.state().statement_to_plan(statement).await?;
+                ctx.execute_logical_plan(plan).await?
             };
 
             let batches = df.collect().await?;
@@ -159,7 +165,7 @@
     }
 
     async fn eval_query_truecard(&self, sql: &str) -> anyhow::Result<usize> {
-        let rows = self.execute(sql).await?;
+        let rows = Self::execute(&self.ctx, sql).await?;
         let num_rows = rows.len();
         Ok(num_rows)
     }
@@ -168,7 +174,7 @@
         lazy_static! {
             static ref ROW_CNT_RE: Regex = Regex::new(r"row_cnt=(\d+\.\d+)").unwrap();
         }
-        let explains = self.execute(&format!("explain verbose {}", sql)).await?;
+        let explains = Self::execute(&self.ctx, &format!("explain verbose {}", sql)).await?;
         // Find first occurrence of row_cnt=... in the output.
         let row_cnt = explains
             .iter()
@@ -195,9 +201,14 @@
     }
 
     async fn load_tpch_data(&mut self, tpch_config: &TpchConfig) -> anyhow::Result<()> {
+        // Generate the tables.
         let tpch_kit = TpchKit::build(&self.workspace_dpath)?;
         tpch_kit.gen_tables(tpch_config)?;
 
+        // Generate the stats.
+        let stats = self.load_tpch_stats(&tpch_kit, tpch_config).await?;
+        self.clear_state(Some(stats)).await?;
+
         // Create the tables.
         let ddls = fs::read_to_string(&tpch_kit.schema_fpath)?;
         let ddls = ddls
@@ -206,18 +217,21 @@
             .filter(|s| !s.is_empty())
             .collect::<Vec<_>>();
         for ddl in ddls {
-            self.execute(ddl).await?;
+            Self::execute(&self.ctx, ddl).await?;
         }
 
         // Load the data by creating an external table first and copying the data to real tables.
         let tbl_fpath_iter = tpch_kit.get_tbl_fpath_iter(tpch_config).unwrap();
         for tbl_fpath in tbl_fpath_iter {
             let tbl_name = tbl_fpath.file_stem().unwrap().to_str().unwrap();
-            self.execute(&format!(
-                "create external table {}_tbl stored as csv delimiter '|' location '{}';",
-                tbl_name,
-                tbl_fpath.to_str().unwrap()
-            ))
+            Self::execute(
+                &self.ctx,
+                &format!(
+                    "create external table {}_tbl stored as csv delimiter '|' location '{}';",
+                    tbl_name,
+                    tbl_fpath.to_str().unwrap()
+                ),
+            )
             .await?;
 
             // Get the number of columns of this table.
@@ -235,14 +249,63 @@
                 .map(|i| format!("column_{}", i))
                 .collect::<Vec<_>>()
                 .join(", ");
-            self.execute(&format!(
-                "insert into {} select {} from {}_tbl;",
-                tbl_name, projection_list, tbl_name,
-            ))
+            Self::execute(
+                &self.ctx,
+                &format!(
+                    "insert into {} select {} from {}_tbl;",
+                    tbl_name, projection_list, tbl_name,
+                ),
+            )
             .await?;
         }
+
         Ok(())
     }
+
+    async fn load_tpch_stats(
+        &self,
+        tpch_kit: &TpchKit,
+        tpch_config: &TpchConfig,
+    ) -> anyhow::Result<BaseTableStats> {
+        // To get the schema of each table.
+        let ctx = Self::new_session_ctx(None).await?;
+        let ddls = fs::read_to_string(&tpch_kit.schema_fpath)?;
+        let ddls = ddls
+            .split(';')
+            .map(|s| s.trim())
+            .filter(|s| !s.is_empty())
+            .collect::<Vec<_>>();
+        for ddl in ddls {
+            Self::execute(&ctx, ddl).await?;
+        }
+        let mut base_table_stats = BaseTableStats::default();
+        for tbl_fpath in tpch_kit.get_tbl_fpath_iter(tpch_config).unwrap() {
+            let tbl_name = tbl_fpath.file_stem().unwrap().to_str().unwrap();
+            let schema = ctx
+                .catalog("datafusion")
+                .unwrap()
+                .schema("public")
+                .unwrap()
+                .table(tbl_name)
+                .await
+                .unwrap()
+                .schema();
+            // Load the .tbl file into record batches using arrow.
+            let tbl_file = fs::File::open(&tbl_fpath)?;
+            let csv_reader = ReaderBuilder::new(schema.clone())
+                .has_header(false)
+                .with_delimiter(b'|')
+                .build(tbl_file)
+                .unwrap();
+            let batch_iter = RecordBatchIterator::new(csv_reader, schema);
+            base_table_stats.insert(
+                tbl_name.to_string(),
+                PerTableStats::from_record_batches(batch_iter)?,
+            );
+            log::debug!("statistics generated for table: {}", tbl_name);
+        }
+        Ok(base_table_stats)
+    }
 }
 
 unsafe impl Send for DatafusionDb {}
diff --git a/optd-perftest/src/lib.rs b/optd-perftest/src/lib.rs
index 56a6557c..96da0d56 100644
--- a/optd-perftest/src/lib.rs
+++ b/optd-perftest/src/lib.rs
@@ -1,6 +1,6 @@
 mod benchmark;
 pub mod cardtest;
-mod datafusion_db_cardtest;
+mod datafusion_db;
 mod postgres_db;
 pub mod shell;
 pub mod tpch;
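Finally, the q-error harness reads optd's estimated cardinality back out of `explain verbose` text. A self-contained re-implementation of the `eval_query_estcard` regex logic above, with a made-up sample plan line:

```rust
use regex::Regex;

/// Return the first `row_cnt=...` value in an `explain verbose` dump,
/// the way `eval_query_estcard` does above.
fn parse_estcard(explain_output: &str) -> Option<usize> {
    let re = Regex::new(r"row_cnt=(\d+\.\d+)").unwrap();
    let caps = re.captures(explain_output)?;
    let row_cnt: f64 = caps.get(1)?.as_str().parse().ok()?;
    Some(row_cnt as usize)
}

fn main() {
    let sample =
        "PhysicalScan { table: orders, cost: weighted=15000.00,row_cnt=15000.00,compute=0.00,io=15000.00 }";
    assert_eq!(parse_estcard(sample), Some(15000));
}
```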