diff --git a/optd-core/src/cascades/optimizer.rs b/optd-core/src/cascades/optimizer.rs index a2e4ea7b..d24eec70 100644 --- a/optd-core/src/cascades/optimizer.rs +++ b/optd-core/src/cascades/optimizer.rs @@ -317,6 +317,10 @@ impl CascadesOptimizer { self.memo.merge_group(group_a, group_b); } + /// Get the properties of a Cascades group + /// P is the type of the property you expect + /// idx is the idx of the property you want. The order of properties is defined + /// by the property_builders parameter in CascadesOptimizer::new() pub fn get_property_by_group>( &self, group_id: GroupId, diff --git a/optd-datafusion-repr/src/cost/base_cost.rs b/optd-datafusion-repr/src/cost/base_cost.rs index 7b19e1ec..5fc0fb1c 100644 --- a/optd-datafusion-repr/src/cost/base_cost.rs +++ b/optd-datafusion-repr/src/cost/base_cost.rs @@ -1,11 +1,12 @@ use std::{collections::HashMap, sync::Arc}; use crate::plan_nodes::{ - BinOpType, ColumnRefExpr, ConstantExpr, ConstantType, ExprList, LogOpType, OptRelNode, UnOpType, + BinOpType, ColumnRefExpr, ConstantExpr, ConstantType, Expr, ExprList, LogOpExpr, LogOpType, + OptRelNode, UnOpType, }; use crate::properties::column_ref::{ColumnRefPropertyBuilder, GroupColumnRefs}; use crate::{ - plan_nodes::{OptRelNodeRef, OptRelNodeTyp}, + plan_nodes::{JoinType, OptRelNodeRef, OptRelNodeTyp}, properties::column_ref::ColumnRef, }; use arrow_schema::{ArrowError, DataType}; @@ -95,6 +96,11 @@ impl MostCommonValues for MockMostCommonValues { #[derive(Serialize, Deserialize)] pub struct PerTableStats { row_cnt: usize, + // This is a Vec of Options instead of just a Vec because some columns may not have stats + // due to their type being non-comparable. + // Further, I chose to represent it as a Vec of Options instead of a HashMap because a Vec + // of Options clearly differentiates between two different failure modes: "out-of-bounds + // access" and "column has no stats". 
per_column_stats_vec: Vec>>, } @@ -332,12 +338,13 @@ const DEFAULT_EQ_SEL: f64 = 0.005; const DEFAULT_INEQ_SEL: f64 = 0.3333333333333333; // Default selectivity estimate for pattern-match operators such as LIKE const DEFAULT_MATCH_SEL: f64 = 0.005; +// Default n-distinct estimate for derived columns or columns lacking statistics +const DEFAULT_NUM_DISTINCT: u64 = 200; // Default selectivity if we have no information const DEFAULT_UNK_SEL: f64 = 0.005; -// Default n-distinct estimate for derived columns or columns lacking statistics -const DEFAULT_N_DISTINCT: u64 = 200; -const INVALID_SEL: f64 = 0.01; +// A placeholder for unimplemented!() for codepaths which are accessed by plannertest +const UNIMPLEMENTED_SEL: f64 = 0.01; impl OptCostModel { pub fn row_cnt(Cost(cost): &Cost) -> f64 { @@ -447,41 +454,45 @@ impl CostModel for OptCostM OptRelNodeTyp::PhysicalFilter => { let (row_cnt, _, _) = Self::cost_tuple(&children[0]); let (_, compute_cost, _) = Self::cost_tuple(&children[1]); - let selectivity = match context { - Some(context) => { - if let Some(optimizer) = optimizer { - let column_refs = optimizer - .get_property_by_group::( - context.group_id, - 1, - ); - let expr_group_id = context.children_group_ids[1]; - let expr_trees = optimizer.get_all_group_bindings(expr_group_id, false); - // there may be more than one expression tree in a group (you can see this trivially as you can just swap the order of two subtrees for commutative operators) - // however, we just take an arbitrary expression tree from the group to compute selectivity - if let Some(expr_tree) = expr_trees.first() { - self.get_filter_selectivity(Arc::clone(expr_tree), &column_refs) - } else { - panic!("encountered a PhysicalFilter without an expression") - } - } else { - panic!("compute_cost() should not be called if optimizer is None") - } - } - None => panic!("compute_cost() should not be called if context is None"), + let selectivity = if let (Some(context), Some(optimizer)) = (context, 
optimizer) { + let column_refs = optimizer + .get_property_by_group::(context.group_id, 1); + let expr_group_id = context.children_group_ids[1]; + let expr_trees = optimizer.get_all_group_bindings(expr_group_id, false); + // there may be more than one expression tree in a group (you can see this trivially as you can just swap the order of two subtrees for commutative operators) + // however, we just take an arbitrary expression tree from the group to compute selectivity + let expr_tree = expr_trees.first().expect("expression missing"); + self.get_filter_selectivity(expr_tree.clone(), &column_refs) + } else { + DEFAULT_UNK_SEL }; - Self::cost( (row_cnt * selectivity).max(1.0), row_cnt * compute_cost, 0.0, ) } - OptRelNodeTyp::PhysicalNestedLoopJoin(_) => { + OptRelNodeTyp::PhysicalNestedLoopJoin(join_typ) => { let (row_cnt_1, _, _) = Self::cost_tuple(&children[0]); let (row_cnt_2, _, _) = Self::cost_tuple(&children[1]); let (_, compute_cost, _) = Self::cost_tuple(&children[2]); - let selectivity = 0.01; + let selectivity = if let (Some(context), Some(optimizer)) = (context, optimizer) { + let column_refs = optimizer + .get_property_by_group::(context.group_id, 1); + let expr_group_id = context.children_group_ids[2]; + let expr_trees = optimizer.get_all_group_bindings(expr_group_id, false); + // there may be more than one expression tree in a group. 
see comment in OptRelNodeTyp::PhysicalFilter(_) for more information + let expr_tree = expr_trees.first().expect("expression missing"); + self.get_join_selectivity_from_expr_tree( + *join_typ, + expr_tree.clone(), + &column_refs, + row_cnt_1, + row_cnt_2, + ) + } else { + DEFAULT_UNK_SEL + }; Self::cost( (row_cnt_1 * row_cnt_2 * selectivity).max(1.0), row_cnt_1 * row_cnt_2 * compute_cost + row_cnt_1, @@ -493,11 +504,36 @@ impl CostModel for OptCostM let (_, compute_cost, _) = Self::cost_tuple(&children[1]); Self::cost(row_cnt, compute_cost * row_cnt, 0.0) } - OptRelNodeTyp::PhysicalHashJoin(_) => { + OptRelNodeTyp::PhysicalHashJoin(join_typ) => { let (row_cnt_1, _, _) = Self::cost_tuple(&children[0]); let (row_cnt_2, _, _) = Self::cost_tuple(&children[1]); + let selectivity = if let (Some(context), Some(optimizer)) = (context, optimizer) { + let column_refs = optimizer + .get_property_by_group::(context.group_id, 1); + let left_keys_group_id = context.children_group_ids[2]; + let right_keys_group_id = context.children_group_ids[3]; + let left_keys_list = + optimizer.get_all_group_bindings(left_keys_group_id, false); + let right_keys_list = + optimizer.get_all_group_bindings(right_keys_group_id, false); + // there may be more than one expression tree in a group. 
see comment in OptRelNodeTyp::PhysicalFilter(_) for more information + let left_keys = left_keys_list.first().expect("left keys missing"); + let right_keys = right_keys_list.first().expect("right keys missing"); + self.get_join_selectivity_from_keys( + *join_typ, + ExprList::from_rel_node(left_keys.clone()) + .expect("left_keys should be an ExprList"), + ExprList::from_rel_node(right_keys.clone()) + .expect("right_keys should be an ExprList"), + &column_refs, + row_cnt_1, + row_cnt_2, + ) + } else { + DEFAULT_UNK_SEL + }; Self::cost( - row_cnt_1.min(row_cnt_2).max(1.0), + (row_cnt_1 * row_cnt_2 * selectivity).max(1.0), row_cnt_1 * 2.0 + row_cnt_2, 0.0, ) @@ -595,10 +631,10 @@ impl OptCostModel { column_stats.ndistinct as f64 } else { // The column type is not supported or stats are missing. - DEFAULT_N_DISTINCT as f64 + DEFAULT_NUM_DISTINCT as f64 } } - ColumnRef::Derived => DEFAULT_N_DISTINCT as f64, + ColumnRef::Derived => DEFAULT_NUM_DISTINCT as f64, _ => panic!( "GROUP BY base table column ref must either be derived or base table" ), @@ -626,8 +662,8 @@ impl OptCostModel { ) -> f64 { assert!(expr_tree.typ.is_expression()); match &expr_tree.typ { - OptRelNodeTyp::Constant(_) => todo!("check bool type or else panic"), - OptRelNodeTyp::ColumnRef => todo!("check bool type or else panic"), + OptRelNodeTyp::Constant(_) => Self::get_constant_selectivity(expr_tree), + OptRelNodeTyp::ColumnRef => unimplemented!("check bool type or else panic"), OptRelNodeTyp::UnOp(un_op_typ) => { assert!(expr_tree.children.len() == 1); let child = expr_tree.child(0); @@ -646,7 +682,7 @@ impl OptCostModel { let right_child = expr_tree.child(1); if bin_op_typ.is_comparison() { - self.get_comparison_op_selectivity( + self.get_filter_comp_op_selectivity( *bin_op_typ, left_child, right_child, @@ -661,27 +697,254 @@ impl OptCostModel { } } OptRelNodeTyp::LogOp(log_op_typ) => { - self.get_log_op_selectivity(*log_op_typ, &expr_tree.children, column_refs) + 
self.get_filter_log_op_selectivity(*log_op_typ, &expr_tree.children, column_refs) } - OptRelNodeTyp::Func(_) => todo!("check bool type or else panic"), + OptRelNodeTyp::Func(_) => unimplemented!("check bool type or else panic"), OptRelNodeTyp::SortOrder(_) => { panic!("the selectivity of sort order expressions is undefined") } - OptRelNodeTyp::Between => INVALID_SEL, - OptRelNodeTyp::Cast => todo!("check bool type or else panic"), + OptRelNodeTyp::Between => UNIMPLEMENTED_SEL, + OptRelNodeTyp::Cast => unimplemented!("check bool type or else panic"), OptRelNodeTyp::Like => DEFAULT_MATCH_SEL, OptRelNodeTyp::DataType(_) => { panic!("the selectivity of a data type is not defined") } - OptRelNodeTyp::InList => INVALID_SEL, + OptRelNodeTyp::InList => UNIMPLEMENTED_SEL, _ => unreachable!( "all expression OptRelNodeTyp were enumerated. this should be unreachable" ), } } + /// Check if an expr_tree is a join condition, returning the join on col ref pair if it is. + /// The reason the check and the info are in the same function is because their code is almost identical. + /// It only picks out equality conditions between two column refs on different tables + fn get_on_col_ref_pair( + expr_tree: OptRelNodeRef, + column_refs: &GroupColumnRefs, + ) -> Option<(ColumnRefExpr, ColumnRefExpr)> { + // 1. Check that it's equality + if expr_tree.typ == OptRelNodeTyp::BinOp(BinOpType::Eq) { + let left_child = expr_tree.child(0); + let right_child = expr_tree.child(1); + // 2. Check that both sides are column refs + if left_child.typ == OptRelNodeTyp::ColumnRef + && right_child.typ == OptRelNodeTyp::ColumnRef + { + // 3. 
Check that both sides don't belong to the same table (if either side is not a base table column ref, we can't tell, so we treat them as not belonging to the same table) + let left_col_ref_expr = ColumnRefExpr::from_rel_node(left_child) + .expect("we already checked that the type is ColumnRef"); + let right_col_ref_expr = ColumnRefExpr::from_rel_node(right_child) + .expect("we already checked that the type is ColumnRef"); + let left_col_ref = &column_refs[left_col_ref_expr.index()]; + let right_col_ref = &column_refs[right_col_ref_expr.index()]; + let is_same_table = if let ( + ColumnRef::BaseTableColumnRef { + table: left_table, .. + }, + ColumnRef::BaseTableColumnRef { + table: right_table, .. + }, + ) = (left_col_ref, right_col_ref) + { + left_table == right_table + } else { + false + }; + if !is_same_table { + Some((left_col_ref_expr, right_col_ref_expr)) + } else { + None + } + } else { + None + } + } else { + None + } + } + + /// The expr_tree input must be a "mixed expression tree", just like with get_filter_selectivity() + /// This is a "wrapper" to separate the equality conditions from the filter conditions before calling + /// the "main" get_join_selectivity_core() function. 
+ fn get_join_selectivity_from_expr_tree( + &self, + join_typ: JoinType, + expr_tree: OptRelNodeRef, + column_refs: &GroupColumnRefs, + left_row_cnt: f64, + right_row_cnt: f64, + ) -> f64 { + assert!(expr_tree.typ.is_expression()); + if expr_tree.typ == OptRelNodeTyp::LogOp(LogOpType::And) { + let mut on_col_ref_pairs = vec![]; + let mut filter_expr_trees = vec![]; + for child_expr_tree in &expr_tree.children { + if let Some(on_col_ref_pair) = + Self::get_on_col_ref_pair(child_expr_tree.clone(), column_refs) + { + on_col_ref_pairs.push(on_col_ref_pair) + } else { + let child_expr = Expr::from_rel_node(child_expr_tree.clone()).expect( + "everything that is a direct child of an And node must be an expression", + ); + filter_expr_trees.push(child_expr); + } + } + assert!(on_col_ref_pairs.len() + filter_expr_trees.len() == expr_tree.children.len()); + let filter_expr_tree = if filter_expr_trees.is_empty() { + None + } else { + Some( + LogOpExpr::new(LogOpType::And, ExprList::new(filter_expr_trees)) + .into_rel_node(), + ) + }; + self.get_join_selectivity_core( + join_typ, + on_col_ref_pairs, + filter_expr_tree, + column_refs, + left_row_cnt, + right_row_cnt, + ) + } else { + #[allow(clippy::collapsible_else_if)] + if let Some(on_col_ref_pair) = Self::get_on_col_ref_pair(expr_tree.clone(), column_refs) + { + self.get_join_selectivity_core( + join_typ, + vec![on_col_ref_pair], + None, + column_refs, + left_row_cnt, + right_row_cnt, + ) + } else { + self.get_join_selectivity_core( + join_typ, + vec![], + Some(expr_tree), + column_refs, + left_row_cnt, + right_row_cnt, + ) + } + } + } + + /// A wrapper to convert the join keys to the format expected by get_join_selectivity_core() + fn get_join_selectivity_from_keys( + &self, + join_typ: JoinType, + left_keys: ExprList, + right_keys: ExprList, + column_refs: &GroupColumnRefs, + left_row_cnt: f64, + right_row_cnt: f64, + ) -> f64 { + assert!(left_keys.len() == right_keys.len()); + // I assume that the keys are already in the 
right order s.t. the ith key of left_keys corresponds with the ith key of right_keys + let on_col_ref_pairs = left_keys + .to_vec() + .into_iter() + .zip(right_keys.to_vec()) + .map(|(left_key, right_key)| { + ( + ColumnRefExpr::from_rel_node(left_key.into_rel_node()) + .expect("keys should be ColumnRefExprs"), + ColumnRefExpr::from_rel_node(right_key.into_rel_node()) + .expect("keys should be ColumnRefExprs"), + ) + }) + .collect_vec(); + self.get_join_selectivity_core( + join_typ, + on_col_ref_pairs, + None, + column_refs, + left_row_cnt, + right_row_cnt, + ) + } + + /// The core logic of join selectivity which assumes we've already separated the expression into the on conditions and the filters + fn get_join_selectivity_core( + &self, + join_typ: JoinType, + on_col_ref_pairs: Vec<(ColumnRefExpr, ColumnRefExpr)>, + filter_expr_tree: Option, + column_refs: &GroupColumnRefs, + left_row_cnt: f64, + right_row_cnt: f64, + ) -> f64 { + let join_on_selectivity = self.get_join_on_selectivity(&on_col_ref_pairs, column_refs); + // Currently, there is no difference in how we handle a join filter and a select filter, so we use the same function + // One difference (that we *don't* care about right now) is that join filters can contain expressions from multiple + // different tables. 
Currently, this doesn't affect the get_filter_selectivity() function, but this may change in + // the future + let join_filter_selectivity = match filter_expr_tree { + Some(filter_expr_tree) => self.get_filter_selectivity(filter_expr_tree, column_refs), + None => 1.0, + }; + let inner_join_selectivity = join_on_selectivity * join_filter_selectivity; + match join_typ { + JoinType::Inner => inner_join_selectivity, + JoinType::LeftOuter => f64::max(inner_join_selectivity, 1.0 / right_row_cnt), + JoinType::RightOuter => f64::max(inner_join_selectivity, 1.0 / left_row_cnt), + JoinType::Cross => { + assert!( + on_col_ref_pairs.is_empty(), + "Cross joins should not have on columns" + ); + join_filter_selectivity + } + _ => unimplemented!("join_typ={} is not implemented", join_typ), + } + } + + fn get_per_column_stats_from_col_ref( + &self, + col_ref: &ColumnRef, + ) -> Option<&PerColumnStats> { + if let ColumnRef::BaseTableColumnRef { table, col_idx } = col_ref { + self.get_per_column_stats(table, *col_idx) + } else { + None + } + } + + fn get_per_column_stats(&self, table: &str, col_idx: usize) -> Option<&PerColumnStats> { + self.per_table_stats_map + .get(table) + .and_then(|per_table_stats| per_table_stats.per_column_stats_vec[col_idx].as_ref()) + } + + /// Get the selectivity of the on conditions + /// Note that the selectivity of the on conditions does not depend on join type. 
Join type is accounted for separately in get_join_selectivity_core() + fn get_join_on_selectivity( + &self, + on_col_ref_pairs: &[(ColumnRefExpr, ColumnRefExpr)], + column_refs: &GroupColumnRefs, + ) -> f64 { + // multiply the selectivities of all individual conditions together + on_col_ref_pairs.iter().map(|on_col_ref_pair| { + // the formula for each pair is min(1 / ndistinct1, 1 / ndistinct2) (see https://postgrespro.com/blog/pgsql/5969618) + let ndistincts = vec![&on_col_ref_pair.0, &on_col_ref_pair.1].into_iter().map(|on_col_ref_expr| { + match self.get_per_column_stats_from_col_ref(&column_refs[on_col_ref_expr.index()]) { + Some(per_col_stats) => per_col_stats.ndistinct, + None => DEFAULT_NUM_DISTINCT, + } + }); + // using reduce(f64::min) is the idiomatic workaround to min() because f64 does not implement Ord due to NaN + let selectivity = ndistincts.map(|ndistinct| 1.0 / ndistinct as f64).reduce(f64::min).expect("reduce() only returns None if the iterator is empty, which is impossible since col_ref_exprs.len() == 2"); + assert!(!selectivity.is_nan(), "it should be impossible for selectivity to be NaN since n-distinct is never 0"); + selectivity + }).product() + } + /// Comparison operators are the base case for recursion in get_filter_selectivity() - fn get_comparison_op_selectivity( + fn get_filter_comp_op_selectivity( &self, comp_bin_op_typ: BinOpType, left: OptRelNodeRef, @@ -690,49 +953,27 @@ impl OptCostModel { ) -> f64 { assert!(comp_bin_op_typ.is_comparison()); - // it's more convenient to refer to the children based on whether they're column nodes or not - // rather than by left/right - let mut col_ref_nodes = vec![]; - let mut non_col_ref_nodes = vec![]; - let is_left_col_ref; // I intentionally performed moves on left and right. 
This way, we don't accidentally use them after this block - // We always want to use "col_ref_node" and "non_col_ref_node" instead of "left" or "right" - if left.as_ref().typ == OptRelNodeTyp::ColumnRef { - is_left_col_ref = true; - col_ref_nodes.push( - ColumnRefExpr::from_rel_node(left) - .expect("we already checked that the type is ColumnRef"), - ); - } else { - is_left_col_ref = false; - non_col_ref_nodes.push(left); - } - if right.as_ref().typ == OptRelNodeTyp::ColumnRef { - col_ref_nodes.push( - ColumnRefExpr::from_rel_node(right) - .expect("we already checked that the type is ColumnRef"), - ); - } else { - non_col_ref_nodes.push(right); - } + let (col_ref_exprs, non_col_ref_exprs, is_left_col_ref) = + Self::get_semantic_nodes(left, right); // handle the different cases of column nodes - if col_ref_nodes.is_empty() { - INVALID_SEL - } else if col_ref_nodes.len() == 1 { - let col_ref_node = col_ref_nodes - .pop() - .expect("we just checked that col_ref_nodes.len() == 1"); - let col_ref_idx = col_ref_node.index(); + if col_ref_exprs.is_empty() { + UNIMPLEMENTED_SEL + } else if col_ref_exprs.len() == 1 { + let col_ref_expr = col_ref_exprs + .first() + .expect("we just checked that col_ref_exprs.len() == 1"); + let col_ref_idx = col_ref_expr.index(); if let ColumnRef::BaseTableColumnRef { table, col_idx } = &column_refs[col_ref_idx] { - let non_col_ref_node = non_col_ref_nodes - .pop() - .expect("non_col_ref_nodes should have a value since col_ref_nodes.len() == 1"); + let non_col_ref_expr = non_col_ref_exprs + .first() + .expect("non_col_ref_exprs should have a value since col_ref_exprs.len() == 1"); - match non_col_ref_node.as_ref().typ { + match non_col_ref_expr.as_ref().typ { OptRelNodeTyp::Constant(_) => { - let value = non_col_ref_node + let value = non_col_ref_expr .as_ref() .data .as_ref() @@ -778,22 +1019,54 @@ impl OptCostModel { OptRelNodeTyp::BinOp(_) => { Self::get_default_comparison_op_selectivity(comp_bin_op_typ) } - OptRelNodeTyp::Cast => 
INVALID_SEL, + OptRelNodeTyp::Cast => UNIMPLEMENTED_SEL, _ => unimplemented!( "unhandled case of comparing a column ref node to {}", - non_col_ref_node.as_ref().typ + non_col_ref_expr.as_ref().typ ), } } else { - unimplemented!("non base table column refs need to be implemented") + Self::get_default_comparison_op_selectivity(comp_bin_op_typ) } - } else if col_ref_nodes.len() == 2 { + } else if col_ref_exprs.len() == 2 { Self::get_default_comparison_op_selectivity(comp_bin_op_typ) } else { - unreachable!("we could have at most pushed left and right into col_ref_nodes") + unreachable!("we could have at most pushed left and right into col_ref_exprs") } } + /// Convert the left and right child nodes of some operation to what they semantically are + /// This is convenient to avoid repeating the same logic just with "left" and "right" swapped + fn get_semantic_nodes( + left: OptRelNodeRef, + right: OptRelNodeRef, + ) -> (Vec, Vec, bool) { + let mut col_ref_exprs = vec![]; + let mut non_col_ref_exprs = vec![]; + let is_left_col_ref; + // I intentionally performed moves on left and right. 
This way, we don't accidentally use them after this block + // We always want to use "col_ref_expr" and "non_col_ref_expr" instead of "left" or "right" + if left.as_ref().typ == OptRelNodeTyp::ColumnRef { + is_left_col_ref = true; + col_ref_exprs.push( + ColumnRefExpr::from_rel_node(left) + .expect("we already checked that the type is ColumnRef"), + ); + } else { + is_left_col_ref = false; + non_col_ref_exprs.push(left); + } + if right.as_ref().typ == OptRelNodeTyp::ColumnRef { + col_ref_exprs.push( + ColumnRefExpr::from_rel_node(right) + .expect("we already checked that the type is ColumnRef"), + ); + } else { + non_col_ref_exprs.push(right); + } + (col_ref_exprs, non_col_ref_exprs, is_left_col_ref) + } + /// The default selectivity of a comparison expression /// Used when one side of the comparison is a column while the other side is something too /// complex/impossible to evaluate (subquery, UDF, another column, we have no stats, etc.) @@ -809,6 +1082,33 @@ impl OptCostModel { } } + fn get_constant_selectivity(const_node: OptRelNodeRef) -> f64 { + if let OptRelNodeTyp::Constant(const_typ) = const_node.typ { + if matches!(const_typ, ConstantType::Bool) { + let value = const_node + .as_ref() + .data + .as_ref() + .expect("constants should have data"); + if let Value::Bool(bool_value) = value { + if *bool_value { + 1.0 + } else { + 0.0 + } + } else { + unreachable!( + "if the typ is ConstantType::Bool, the value should be a Value::Bool" + ) + } + } else { + panic!("selectivity is not defined on constants which are not bools") + } + } else { + panic!("get_constant_selectivity must be called on a constant") + } + } + /// Get the selectivity of an expression of the form "column equals value" (or "value equals column") /// Will handle the case of statistics missing /// Equality predicates are handled entirely differently from range predicates so this is its own function @@ -822,31 +1122,21 @@ impl OptCostModel { value: &Value, is_eq: bool, ) -> f64 { - if let 
Some(per_table_stats) = self.per_table_stats_map.get(table) { - if let Some(Some(per_column_stats)) = per_table_stats.per_column_stats_vec.get(col_idx) - { - let eq_freq = if let Some(freq) = per_column_stats.mcvs.freq(value) { - freq - } else { - let non_mcv_freq = 1.0 - per_column_stats.mcvs.total_freq(); - // always safe because usize is at least as large as i32 - let ndistinct_as_usize = per_column_stats.ndistinct as usize; - let non_mcv_cnt = ndistinct_as_usize - per_column_stats.mcvs.cnt(); - // note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt - 1 if null_frac > 0 - (non_mcv_freq - per_column_stats.null_frac) / (non_mcv_cnt as f64) - }; - if is_eq { - eq_freq - } else { - 1.0 - eq_freq - per_column_stats.null_frac - } + if let Some(per_column_stats) = self.get_per_column_stats(table, col_idx) { + let eq_freq = if let Some(freq) = per_column_stats.mcvs.freq(value) { + freq } else { - #[allow(clippy::collapsible_else_if)] - if is_eq { - DEFAULT_EQ_SEL - } else { - 1.0 - DEFAULT_EQ_SEL - } + let non_mcv_freq = 1.0 - per_column_stats.mcvs.total_freq(); + // always safe because usize is at least as large as i32 + let ndistinct_as_usize = per_column_stats.ndistinct as usize; + let non_mcv_cnt = ndistinct_as_usize - per_column_stats.mcvs.cnt(); + // note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt - 1 if null_frac > 0 + (non_mcv_freq - per_column_stats.null_frac) / (non_mcv_cnt as f64) + }; + if is_eq { + eq_freq + } else { + 1.0 - eq_freq - per_column_stats.null_frac } } else { #[allow(clippy::collapsible_else_if)] @@ -872,53 +1162,48 @@ impl OptCostModel { is_col_lt_val: bool, is_col_eq_val: bool, ) -> f64 { - if let Some(per_table_stats) = self.per_table_stats_map.get(table) { - if let Some(Some(per_column_stats)) = per_table_stats.per_column_stats_vec.get(col_idx) - { - // because distr does not include the values in MCVs, we need to compute the CDFs there as well - // because nulls return false in 
any comparison, they are never included when computing range selectivity - let distr_leq_freq = per_column_stats.distr.cdf(value); - let value_clone = value.clone(); // clone the value so that we can move it into the closure to avoid lifetime issues - // TODO: in a future PR, figure out how to make Values comparable. rn I just hardcoded as_i32() to work around this - let pred = Box::new(move |val: &Value| val.as_i32() <= value_clone.as_i32()); - let mcvs_leq_freq = per_column_stats.mcvs.freq_over_pred(pred); - let total_leq_freq = distr_leq_freq + mcvs_leq_freq; - - // depending on whether value is in mcvs or not, we use different logic to turn total_leq_cdf into total_lt_cdf - // this logic just so happens to be the exact same logic as get_column_equality_selectivity implements - let total_lt_freq = total_leq_freq - - self.get_column_equality_selectivity(table, col_idx, value, true); - - // use either total_leq_freq or total_lt_freq to get the selectivity - if is_col_lt_val { - if is_col_eq_val { - // this branch means <= - total_leq_freq - } else { - // this branch means < - total_lt_freq - } + if let Some(per_column_stats) = self.get_per_column_stats(table, col_idx) { + // because distr does not include the values in MCVs, we need to compute the CDFs there as well + // because nulls return false in any comparison, they are never included when computing range selectivity + let distr_leq_freq = per_column_stats.distr.cdf(value); + let value_clone = value.clone(); // clone the value so that we can move it into the closure to avoid lifetime issues + // TODO: in a future PR, figure out how to make Values comparable. 
rn I just hardcoded as_i32() to work around this + let pred = Box::new(move |val: &Value| val.as_i32() <= value_clone.as_i32()); + let mcvs_leq_freq = per_column_stats.mcvs.freq_over_pred(pred); + let total_leq_freq = distr_leq_freq + mcvs_leq_freq; + + // depending on whether value is in mcvs or not, we use different logic to turn total_leq_cdf into total_lt_cdf + // this logic just so happens to be the exact same logic as get_column_equality_selectivity implements + let total_lt_freq = + total_leq_freq - self.get_column_equality_selectivity(table, col_idx, value, true); + + // use either total_leq_freq or total_lt_freq to get the selectivity + if is_col_lt_val { + if is_col_eq_val { + // this branch means <= + total_leq_freq } else { - // clippy wants me to collapse this into an else if, but keeping two nested if else statements is clearer - #[allow(clippy::collapsible_else_if)] - if is_col_eq_val { - // this branch means >=, which is 1 - < - null_frac - // we need to subtract null_frac since that isn't included in >= either - 1.0 - total_lt_freq - per_column_stats.null_frac - } else { - // this branch means >. same logic as above - 1.0 - total_leq_freq - per_column_stats.null_frac - } + // this branch means < + total_lt_freq } } else { - DEFAULT_INEQ_SEL + // clippy wants me to collapse this into an else if, but keeping two nested if else statements is clearer + #[allow(clippy::collapsible_else_if)] + if is_col_eq_val { + // this branch means >=, which is 1 - < - null_frac + // we need to subtract null_frac since that isn't included in >= either + 1.0 - total_lt_freq - per_column_stats.null_frac + } else { + // this branch means >. 
same logic as above + 1.0 - total_leq_freq - per_column_stats.null_frac + } } } else { DEFAULT_INEQ_SEL } } - fn get_log_op_selectivity( + fn get_filter_log_op_selectivity( &self, log_op_typ: LogOpType, children: &[OptRelNodeRef], @@ -960,15 +1245,17 @@ mod tests { use std::collections::HashMap; use crate::{ + cost::base_cost::DEFAULT_EQ_SEL, plan_nodes::{ - BinOpExpr, BinOpType, ColumnRefExpr, ConstantExpr, Expr, ExprList, LogOpExpr, + BinOpExpr, BinOpType, ColumnRefExpr, ConstantExpr, Expr, ExprList, JoinType, LogOpExpr, LogOpType, OptRelNode, OptRelNodeRef, UnOpExpr, UnOpType, }, - properties::column_ref::ColumnRef, + properties::column_ref::{ColumnRef, GroupColumnRefs}, }; use super::{Distribution, MostCommonValues, OptCostModel, PerColumnStats, PerTableStats}; type TestPerColumnStats = PerColumnStats; + type TestOptCostModel = OptCostModel; struct TestMostCommonValues { mcvs: HashMap, @@ -1030,12 +1317,11 @@ mod tests { } } - const TABLE1_NAME: &str = "t1"; + const TABLE1_NAME: &str = "table1"; + const TABLE2_NAME: &str = "table2"; - // one column is sufficient for all filter selectivity predicates - fn create_one_column_cost_model( - per_column_stats: TestPerColumnStats, - ) -> OptCostModel { + // one column is sufficient for all filter selectivity tests + fn create_one_column_cost_model(per_column_stats: TestPerColumnStats) -> TestOptCostModel { OptCostModel::new( vec![( String::from(TABLE1_NAME), @@ -1046,6 +1332,42 @@ mod tests { ) } + /// Two tables with one column each are sufficient for all join selectivity tests + fn create_two_table_cost_model( + tbl1_per_column_stats: TestPerColumnStats, + tbl2_per_column_stats: TestPerColumnStats, + ) -> TestOptCostModel { + create_two_table_cost_model_custom_row_cnts( + tbl1_per_column_stats, + tbl2_per_column_stats, + 100, + 100, + ) + } + + /// We need custom row counts because some join algorithms rely on the row cnt + fn create_two_table_cost_model_custom_row_cnts( + tbl1_per_column_stats: TestPerColumnStats, + 
tbl2_per_column_stats: TestPerColumnStats, + tbl1_row_cnt: usize, + tbl2_row_cnt: usize, + ) -> TestOptCostModel { + OptCostModel::new( + vec![ + ( + String::from(TABLE1_NAME), + PerTableStats::new(tbl1_row_cnt, vec![Some(tbl1_per_column_stats)]), + ), + ( + String::from(TABLE2_NAME), + PerTableStats::new(tbl2_row_cnt, vec![Some(tbl2_per_column_stats)]), + ), + ] + .into_iter() + .collect(), + ) + } + fn col_ref(idx: u64) -> OptRelNodeRef { // this conversion is always safe because idx was originally a usize let idx_as_usize = idx as usize; @@ -1088,8 +1410,33 @@ mod tests { .into_rel_node() } + /// The reason this isn't an associated function of PerColumnStats is because that would require + /// adding an empty() function to the trait definitions of MostCommonValues and Distribution, + /// which I wanted to avoid + fn get_empty_per_col_stats() -> TestPerColumnStats { + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 0, + 0.0, + TestDistribution::empty(), + ) + } + + #[test] + fn test_filtersel_const() { + let cost_model = create_one_column_cost_model(get_empty_per_col_stats()); + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(cnst(Value::Bool(true)), &vec![]), + 1.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model.get_filter_selectivity(cnst(Value::Bool(false)), &vec![]), + 0.0 + ); + } + #[test] - fn test_colref_eq_constint_in_mcv() { + fn test_filtersel_colref_eq_constint_in_mcv() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::new(vec![(Value::Int32(1), 0.3)]), 0, @@ -1113,7 +1460,7 @@ mod tests { } #[test] - fn test_colref_eq_constint_not_in_mcv_no_nulls() { + fn test_filtersel_colref_eq_constint_not_in_mcv_no_nulls() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::new(vec![(Value::Int32(1), 0.2), (Value::Int32(3), 0.44)]), 5, @@ -1137,7 +1484,7 @@ mod tests { } #[test] - fn test_colref_eq_constint_not_in_mcv_with_nulls() 
{ + fn test_filtersel_colref_eq_constint_not_in_mcv_with_nulls() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::new(vec![(Value::Int32(1), 0.2), (Value::Int32(3), 0.44)]), 5, @@ -1162,7 +1509,7 @@ mod tests { /// I only have one test for NEQ since I'll assume that it uses the same underlying logic as EQ #[test] - fn test_colref_neq_constint_in_mcv() { + fn test_filtersel_colref_neq_constint_in_mcv() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::new(vec![(Value::Int32(1), 0.3)]), 0, @@ -1186,7 +1533,7 @@ mod tests { } #[test] - fn test_colref_leq_constint_no_mcvs_in_range() { + fn test_filtersel_colref_leq_constint_no_mcvs_in_range() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::empty(), 10, @@ -1210,7 +1557,7 @@ mod tests { } #[test] - fn test_colref_leq_constint_no_mcvs_in_range_with_nulls() { + fn test_filtersel_colref_leq_constint_no_mcvs_in_range_with_nulls() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::empty(), 10, @@ -1234,7 +1581,7 @@ mod tests { } #[test] - fn test_colref_leq_constint_with_mcvs_in_range_not_at_border() { + fn test_filtersel_colref_leq_constint_with_mcvs_in_range_not_at_border() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues { mcvs: vec![ @@ -1267,7 +1614,7 @@ mod tests { } #[test] - fn test_colref_leq_constint_with_mcv_at_border() { + fn test_filtersel_colref_leq_constint_with_mcv_at_border() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::new(vec![ (Value::Int32(6), 0.05), @@ -1296,7 +1643,7 @@ mod tests { } #[test] - fn test_colref_lt_constint_no_mcvs_in_range() { + fn test_filtersel_colref_lt_constint_no_mcvs_in_range() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::empty(), 10, @@ -1320,7 +1667,7 @@ mod 
tests { } #[test] - fn test_colref_lt_constint_no_mcvs_in_range_with_nulls() { + fn test_filtersel_colref_lt_constint_no_mcvs_in_range_with_nulls() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::empty(), 9, // 90% of the values aren't nulls since null_frac = 0.1. if there are 9 distinct non-null values, each will have 0.1 frequency @@ -1344,7 +1691,7 @@ mod tests { } #[test] - fn test_colref_lt_constint_with_mcvs_in_range_not_at_border() { + fn test_filtersel_colref_lt_constint_with_mcvs_in_range_not_at_border() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues { mcvs: vec![ @@ -1362,6 +1709,7 @@ mod tests { )); let expr_tree = bin_op(BinOpType::Lt, col_ref(0), cnst(Value::Int32(15))); let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), col_ref(0)); + // TODO(phw2): make column_refs a function let column_refs = vec![ColumnRef::BaseTableColumnRef { table: String::from(TABLE1_NAME), col_idx: 0, @@ -1377,7 +1725,7 @@ mod tests { } #[test] - fn test_colref_lt_constint_with_mcv_at_border() { + fn test_filtersel_colref_lt_constint_with_mcv_at_border() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues { mcvs: vec![ @@ -1412,7 +1760,7 @@ mod tests { /// I have fewer tests for GT since I'll assume that it uses the same underlying logic as LEQ /// The only interesting thing to test is that if there are nulls, those aren't included in GT #[test] - fn test_colref_gt_constint_no_nulls() { + fn test_filtersel_colref_gt_constint_no_nulls() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::empty(), 10, @@ -1436,7 +1784,7 @@ mod tests { } #[test] - fn test_colref_gt_constint_with_nulls() { + fn test_filtersel_colref_gt_constint_with_nulls() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::empty(), 10, @@ -1462,7 +1810,7 @@ mod tests { /// As 
with above, I have one test without nulls and one test with nulls #[test] - fn test_colref_geq_constint_no_nulls() { + fn test_filtersel_colref_geq_constint_no_nulls() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::empty(), 10, @@ -1486,7 +1834,7 @@ mod tests { } #[test] - fn test_colref_geq_constint_with_nulls() { + fn test_filtersel_colref_geq_constint_with_nulls() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::empty(), 9, // 90% of the values aren't nulls since null_frac = 0.1. if there are 9 distinct non-null values, each will have 0.1 frequency @@ -1511,7 +1859,7 @@ mod tests { } #[test] - fn test_and() { + fn test_filtersel_and() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues { mcvs: vec![ @@ -1551,7 +1899,7 @@ mod tests { } #[test] - fn test_or() { + fn test_filtersel_or() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues { mcvs: vec![ @@ -1591,7 +1939,7 @@ mod tests { } #[test] - fn test_not_no_nulls() { + fn test_filtersel_not_no_nulls() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::new(vec![(Value::Int32(1), 0.3)]), 0, @@ -1613,7 +1961,7 @@ mod tests { } #[test] - fn test_not_with_nulls() { + fn test_filtersel_not_with_nulls() { let cost_model = create_one_column_cost_model(TestPerColumnStats::new( TestMostCommonValues::new(vec![(Value::Int32(1), 0.3)]), 0, @@ -1635,4 +1983,641 @@ mod tests { 0.7 ); } + + /// A wrapper around get_join_selectivity_from_expr_tree that extracts the table row counts from the cost model + fn test_get_join_selectivity( + cost_model: &TestOptCostModel, + reverse_tables: bool, + join_typ: JoinType, + expr_tree: OptRelNodeRef, + column_refs: &GroupColumnRefs, + ) -> f64 { + let table1_row_cnt = cost_model.per_table_stats_map[TABLE1_NAME].row_cnt as f64; + let table2_row_cnt = 
cost_model.per_table_stats_map[TABLE2_NAME].row_cnt as f64; + if !reverse_tables { + cost_model.get_join_selectivity_from_expr_tree( + join_typ, + expr_tree, + column_refs, + table1_row_cnt, + table2_row_cnt, + ) + } else { + cost_model.get_join_selectivity_from_expr_tree( + join_typ, + expr_tree, + column_refs, + table2_row_cnt, + table1_row_cnt, + ) + } + } + + #[test] + fn test_joinsel_inner_const() { + let cost_model = create_one_column_cost_model(get_empty_per_col_stats()); + assert_approx_eq::assert_approx_eq!( + cost_model.get_join_selectivity_from_expr_tree( + JoinType::Inner, + cnst(Value::Bool(true)), + &vec![], + f64::NAN, + f64::NAN + ), + 1.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model.get_join_selectivity_from_expr_tree( + JoinType::Inner, + cnst(Value::Bool(false)), + &vec![], + f64::NAN, + f64::NAN + ), + 0.0 + ); + } + + #[test] + fn test_joinsel_inner_oncond() { + let cost_model = create_two_table_cost_model( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + TestDistribution::empty(), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + TestDistribution::empty(), + ), + ); + let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); + let column_refs = vec![ + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE1_NAME), + col_idx: 0, + }, + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE2_NAME), + col_idx: 0, + }, + ]; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity(&cost_model, false, JoinType::Inner, expr_tree, &column_refs), + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &column_refs + ), + 0.2 + ); + } + + #[test] + fn test_joinsel_inner_and_of_onconds() { + let cost_model = create_two_table_cost_model( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + 
TestDistribution::empty(), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + TestDistribution::empty(), + ), + ); + let eq0and1 = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let eq1and0 = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); + let expr_tree = log_op(LogOpType::And, vec![eq0and1.clone(), eq1and0.clone()]); + let expr_tree_rev = log_op(LogOpType::And, vec![eq1and0.clone(), eq0and1.clone()]); + let column_refs = vec![ + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE1_NAME), + col_idx: 0, + }, + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE2_NAME), + col_idx: 0, + }, + ]; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity(&cost_model, false, JoinType::Inner, expr_tree, &column_refs), + 0.04 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &column_refs + ), + 0.04 + ); + } + + #[test] + fn test_joinsel_inner_and_of_oncond_and_filter() { + let cost_model = create_two_table_cost_model( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + TestDistribution::empty(), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + TestDistribution::empty(), + ), + ); + let eq0and1 = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let eq100 = bin_op(BinOpType::Eq, col_ref(1), cnst(Value::Int32(100))); + let expr_tree = log_op(LogOpType::And, vec![eq0and1.clone(), eq100.clone()]); + let expr_tree_rev = log_op(LogOpType::And, vec![eq100.clone(), eq0and1.clone()]); + let column_refs = vec![ + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE1_NAME), + col_idx: 0, + }, + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE2_NAME), + col_idx: 0, + }, + ]; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity(&cost_model, false, JoinType::Inner, expr_tree, &column_refs), + 0.05 + ); + assert_approx_eq::assert_approx_eq!( + 
test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &column_refs + ), + 0.05 + ); + } + + #[test] + fn test_joinsel_inner_and_of_filters() { + let cost_model = create_two_table_cost_model( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + TestDistribution::empty(), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + TestDistribution::empty(), + ), + ); + let neq12 = bin_op(BinOpType::Neq, col_ref(0), cnst(Value::Int32(12))); + let eq100 = bin_op(BinOpType::Eq, col_ref(1), cnst(Value::Int32(100))); + let expr_tree = log_op(LogOpType::And, vec![neq12.clone(), eq100.clone()]); + let expr_tree_rev = log_op(LogOpType::And, vec![eq100.clone(), neq12.clone()]); + let column_refs = vec![ + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE1_NAME), + col_idx: 0, + }, + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE2_NAME), + col_idx: 0, + }, + ]; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity(&cost_model, false, JoinType::Inner, expr_tree, &column_refs), + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &column_refs + ), + 0.2 + ); + } + + #[test] + fn test_joinsel_inner_colref_eq_colref_same_table_is_not_oncond() { + let cost_model = create_two_table_cost_model( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + TestDistribution::empty(), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + TestDistribution::empty(), + ), + ); + let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(0)); + let column_refs = vec![ + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE1_NAME), + col_idx: 0, + }, + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE2_NAME), + col_idx: 0, + }, + ]; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity(&cost_model, false, JoinType::Inner, expr_tree, 
&column_refs), + DEFAULT_EQ_SEL + ); + } + + // We don't test joinsel or with oncond because if there is an oncond (on condition), the top-level operator must be an AND + + /// I made this helper function to avoid copying all eight lines over and over + fn assert_joinsel_outer_selectivities( + cost_model: &TestOptCostModel, + expr_tree: OptRelNodeRef, + expr_tree_rev: OptRelNodeRef, + column_refs: &GroupColumnRefs, + expected_table1_outer_sel: f64, + expected_table2_outer_sel: f64, + ) { + // all table 1 outer combinations + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::LeftOuter, + expr_tree.clone(), + column_refs + ), + expected_table1_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::LeftOuter, + expr_tree_rev.clone(), + column_refs + ), + expected_table1_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::RightOuter, + expr_tree.clone(), + column_refs + ), + expected_table1_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::RightOuter, + expr_tree_rev.clone(), + column_refs + ), + expected_table1_outer_sel + ); + // all table 2 outer combinations + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::LeftOuter, + expr_tree.clone(), + column_refs + ), + expected_table2_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::LeftOuter, + expr_tree_rev.clone(), + column_refs + ), + expected_table2_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::RightOuter, + expr_tree.clone(), + column_refs + ), + expected_table2_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::RightOuter, 
+ expr_tree_rev.clone(), + column_refs + ), + expected_table2_outer_sel + ); + } + + /// Unique oncond means an oncondition on columns which are unique in both tables + /// There's only one case if both columns are unique and have different row counts: the inner will be < 1 / row count + /// of one table and = 1 / row count of another + #[test] + fn test_joinsel_outer_unique_oncond() { + let cost_model = create_two_table_cost_model_custom_row_cnts( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + TestDistribution::empty(), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + TestDistribution::empty(), + ), + 5, + 4, + ); + // the left/right of the join refers to the tables, not the order of columns in the predicate + let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); + let column_refs = vec![ + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE1_NAME), + col_idx: 0, + }, + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE2_NAME), + col_idx: 0, + }, + ]; + // sanity check the expected inner sel + let expected_inner_sel = 0.2; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &column_refs + ), + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev.clone(), + &column_refs + ), + expected_inner_sel + ); + // check the outer sels + assert_joinsel_outer_selectivities( + &cost_model, + expr_tree, + expr_tree_rev, + &column_refs, + 0.25, + 0.2, + ); + } + + /// Non-unique oncond means the column is not unique in either table + /// Inner always >= row count means that the inner join result is >= 1 / the row count of both tables + #[test] + fn test_joinsel_outer_nonunique_oncond_inner_always_geq_rowcnt() { + let cost_model = 
create_two_table_cost_model_custom_row_cnts( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 5, + 0.0, + TestDistribution::empty(), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + TestDistribution::empty(), + ), + 10, + 8, + ); + // the left/right of the join refers to the tables, not the order of columns in the predicate + let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); + let column_refs = vec![ + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE1_NAME), + col_idx: 0, + }, + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE2_NAME), + col_idx: 0, + }, + ]; + // sanity check the expected inner sel + let expected_inner_sel = 0.2; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &column_refs + ), + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev.clone(), + &column_refs + ), + expected_inner_sel + ); + // check the outer sels + assert_joinsel_outer_selectivities( + &cost_model, + expr_tree, + expr_tree_rev, + &column_refs, + 0.2, + 0.2, + ); + } + + /// Non-unique oncond means the column is not unique in either table + /// Inner sometimes < row count means that the inner join result < 1 / the row count of exactly one table. 
+ /// Note that without a join filter, it's impossible to be less than the row count of both tables + #[test] + fn test_joinsel_outer_nonunique_oncond_inner_sometimes_lt_rowcnt() { + let cost_model = create_two_table_cost_model_custom_row_cnts( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 10, + 0.0, + TestDistribution::empty(), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 2, + 0.0, + TestDistribution::empty(), + ), + 20, + 4, + ); + // the left/right of the join refers to the tables, not the order of columns in the predicate + let expr_tree = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); + let column_refs = vec![ + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE1_NAME), + col_idx: 0, + }, + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE2_NAME), + col_idx: 0, + }, + ]; + // sanity check the expected inner sel + let expected_inner_sel = 0.1; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &column_refs + ), + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev.clone(), + &column_refs + ), + expected_inner_sel + ); + // check the outer sels + assert_joinsel_outer_selectivities( + &cost_model, + expr_tree, + expr_tree_rev, + &column_refs, + 0.25, + 0.1, + ); + } + + /// Unique oncond means an oncondition on columns which are unique in both tables + /// Filter means we're adding a join filter + /// There's only one case if both columns are unique and there's a filter: the inner will be < 1 / row count of both tables + #[test] + fn test_joinsel_outer_unique_oncond_filter() { + let cost_model = create_two_table_cost_model_custom_row_cnts( + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 50, + 0.0, + TestDistribution::new(vec![(Value::Int32(128), 
0.4)]), + ), + TestPerColumnStats::new( + TestMostCommonValues::empty(), + 4, + 0.0, + TestDistribution::empty(), + ), + 50, + 4, + ); + // the left/right of the join refers to the tables, not the order of columns in the predicate + let eq0and1 = bin_op(BinOpType::Eq, col_ref(0), col_ref(1)); + let eq1and0 = bin_op(BinOpType::Eq, col_ref(1), col_ref(0)); + let filter = bin_op(BinOpType::Leq, col_ref(0), cnst(Value::Int32(128))); + let expr_tree = log_op(LogOpType::And, vec![eq0and1, filter.clone()]); + // inner rev means its the inner expr (the eq op) whose children are being reversed, as opposed to the and op + let expr_tree_inner_rev = log_op(LogOpType::And, vec![eq1and0, filter.clone()]); + let column_refs = vec![ + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE1_NAME), + col_idx: 0, + }, + ColumnRef::BaseTableColumnRef { + table: String::from(TABLE2_NAME), + col_idx: 0, + }, + ]; + // sanity check the expected inner sel + let expected_inner_sel = 0.008; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &column_refs + ), + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_inner_rev.clone(), + &column_refs + ), + expected_inner_sel + ); + // check the outer sels + assert_joinsel_outer_selectivities( + &cost_model, + expr_tree, + expr_tree_inner_rev, + &column_refs, + 0.25, + 0.02, + ); + } + + // I didn't test any non-unique cases with filter. 
The non-unique tests without filter should cover that } diff --git a/optd-datafusion-repr/src/plan_nodes.rs b/optd-datafusion-repr/src/plan_nodes.rs index faf27e3a..e872b3a9 100644 --- a/optd-datafusion-repr/src/plan_nodes.rs +++ b/optd-datafusion-repr/src/plan_nodes.rs @@ -39,6 +39,8 @@ pub use sort::{LogicalSort, PhysicalSort}; use crate::properties::schema::{Schema, SchemaPropertyBuilder}; +/// OptRelNodeTyp FAQ: +/// - The define_plan_node!() macro defines what the children of each join node are #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum OptRelNodeTyp { Placeholder(GroupId), diff --git a/optd-perftest/src/cardtest.rs b/optd-perftest/src/cardtest.rs index a7de677a..0b9158cb 100644 --- a/optd-perftest/src/cardtest.rs +++ b/optd-perftest/src/cardtest.rs @@ -103,14 +103,14 @@ pub trait CardtestRunnerDBMSHelper { pub async fn cardtest>( workspace_dpath: P, - no_cached_optd_stats: bool, + rebuild_cached_optd_stats: bool, pguser: &str, pgpassword: &str, tpch_config: TpchConfig, ) -> anyhow::Result>> { let pg_dbms = Box::new(PostgresDBMS::build(&workspace_dpath, pguser, pgpassword)?); let truecard_getter = pg_dbms.clone(); - let df_dbms = Box::new(DatafusionDBMS::new(&workspace_dpath, no_cached_optd_stats).await?); + let df_dbms = Box::new(DatafusionDBMS::new(&workspace_dpath, rebuild_cached_optd_stats).await?); let dbmss: Vec> = vec![pg_dbms, df_dbms]; let tpch_benchmark = Benchmark::Tpch(tpch_config.clone()); diff --git a/optd-perftest/src/datafusion_dbms.rs b/optd-perftest/src/datafusion_dbms.rs index e98d93e6..25f76b34 100644 --- a/optd-perftest/src/datafusion_dbms.rs +++ b/optd-perftest/src/datafusion_dbms.rs @@ -34,7 +34,7 @@ use regex::Regex; pub struct DatafusionDBMS { workspace_dpath: PathBuf, - no_cached_stats: bool, + rebuild_cached_stats: bool, ctx: SessionContext, } @@ -63,11 +63,11 @@ impl CardtestRunnerDBMSHelper for DatafusionDBMS { impl DatafusionDBMS { pub async fn new>( workspace_dpath: P, - no_cached_stats: bool, + rebuild_cached_stats: 
bool, ) -> anyhow::Result { Ok(DatafusionDBMS { workspace_dpath: workspace_dpath.as_ref().to_path_buf(), - no_cached_stats, + rebuild_cached_stats, ctx: Self::new_session_ctx(None).await?, }) } @@ -145,13 +145,13 @@ impl DatafusionDBMS { let mut estcards = vec![]; for (query_id, sql_fpath) in tpch_kit.get_sql_fpath_ordered_iter(tpch_config)? { - let sql = fs::read_to_string(sql_fpath)?; - let estcard = self.eval_query_estcard(&sql).await?; - estcards.push(estcard); println!( - "done evaluating datafusion's estcard for TPC-H Q{}", + "about to evaluate datafusion's estcard for TPC-H Q{}", query_id ); + let sql = fs::read_to_string(sql_fpath)?; + let estcard = self.eval_query_estcard(&sql).await?; + estcards.push(estcard); } Ok(estcards) @@ -213,7 +213,7 @@ impl DatafusionDBMS { .workspace_dpath .join("datafusion_stats_caches") .join(format!("{}.json", benchmark_fname)); - if !self.no_cached_stats && stats_cache_fpath.exists() { + if !self.rebuild_cached_stats && stats_cache_fpath.exists() { let file = File::open(&stats_cache_fpath)?; Ok(serde_json::from_reader(file)?) } else { @@ -222,9 +222,8 @@ impl DatafusionDBMS { _ => unimplemented!(), }; - // regardless of whether self.no_cached_stats is true or false, we want to update the cache - // this way, even if we choose not to read from the cache, the cache still always has the - // most up to date version of the stats + // When self.rebuild_cached_stats is true, we *don't read* from the cache but we still + // *do write* to the cache. 
fs::create_dir_all(stats_cache_fpath.parent().unwrap())?; let file = File::create(&stats_cache_fpath)?; serde_json::to_writer(file, &base_table_stats)?; diff --git a/optd-perftest/src/main.rs b/optd-perftest/src/main.rs index 0611b746..6a28cfd0 100644 --- a/optd-perftest/src/main.rs +++ b/optd-perftest/src/main.rs @@ -39,11 +39,11 @@ enum Commands { #[clap(long)] #[clap(action)] #[clap(help = "Whether to use the cached optd stats/cache generated stats")] - // this is an option because you want to make it false whenever you update the + // this is an option because you want to make it true whenever you update the // code for how stats are generated in optd, in order to not use cached stats // I found that I almost always want to use the cache though, which is why the // system will use the cache by default - no_cached_optd_stats: bool, + rebuild_cached_optd_stats: bool, #[clap(long)] #[clap(default_value = "default_user")] @@ -77,7 +77,7 @@ async fn main() -> anyhow::Result<()> { scale_factor, seed, query_ids, - no_cached_optd_stats, + rebuild_cached_optd_stats, pguser, pgpassword, } => { @@ -89,7 +89,7 @@ async fn main() -> anyhow::Result<()> { }; let cardinfo_alldbs = cardtest::cardtest( &workspace_dpath, - no_cached_optd_stats, + rebuild_cached_optd_stats, &pguser, &pgpassword, tpch_config, diff --git a/optd-perftest/tests/cardtest_integration.rs b/optd-perftest/tests/cardtest_integration.rs index 8b5c242d..327d4fa7 100644 --- a/optd-perftest/tests/cardtest_integration.rs +++ b/optd-perftest/tests/cardtest_integration.rs @@ -44,7 +44,7 @@ mod tests { // make sure scale factor is low so the test runs fast "--scale-factor", "0.01", - "--no-cached-optd-stats", + "--rebuild-cached-optd-stats", "--pguser", "test_user", "--pgpassword",