Skip to content

Commit

Permalink
feat: add cost estimation for agg (#144)
Browse files Browse the repository at this point in the history
Compute cost for aggregation. The cardinality is computed as the product
of n-distinct of all the group-by columns. If there's no group by
column, the output cardinality is just 1.

This should fix the cardinality parity between postgres for Q14 and Q17.
It also leads to a better join order in Q11, since aggregation is the
child of a join.

## Misc

- Add planner test for Q11.
- Fixes Q14 and Q17.
- Next step is to support n-distinct for string.
- We may change to multi-dimension n-distinct after it's integrated.
  • Loading branch information
Gun9niR authored Mar 31, 2024
1 parent f42a3cd commit ee080d8
Show file tree
Hide file tree
Showing 6 changed files with 378 additions and 93 deletions.
132 changes: 99 additions & 33 deletions optd-datafusion-repr/src/cost/base_cost.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::Arc};

use crate::plan_nodes::{
BinOpType, ColumnRefExpr, ConstantExpr, ConstantType, LogOpType, OptRelNode, UnOpType,
BinOpType, ColumnRefExpr, ConstantExpr, ConstantType, ExprList, LogOpType, OptRelNode, UnOpType,
};
use crate::properties::column_ref::{ColumnRefPropertyBuilder, GroupColumnRefs};
use crate::{
Expand All @@ -11,8 +11,8 @@ use crate::{
use arrow_schema::{ArrowError, DataType};
use datafusion::arrow::array::{
Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, Int16Array,
Int32Array, Int8Array, RecordBatch, RecordBatchIterator, RecordBatchReader, UInt16Array,
UInt32Array, UInt8Array,
Int32Array, Int8Array, RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
UInt16Array, UInt32Array, UInt8Array,
};
use itertools::Itertools;
use optd_core::{
Expand All @@ -22,6 +22,7 @@ use optd_core::{
};
use optd_gungnir::stats::hyperloglog::{self, HyperLogLog};
use optd_gungnir::stats::tdigest::{self, TDigest};
use optd_gungnir::utils::arith_encoder;
use serde::{Deserialize, Serialize};

fn compute_plan_node_cost<T: RelNodeTyp, C: CostModel<T>>(
Expand Down Expand Up @@ -181,6 +182,7 @@ impl DataFusionPerTableStats {
| DataType::UInt32
| DataType::Float32
| DataType::Float64
| DataType::Utf8
)
}

Expand Down Expand Up @@ -222,6 +224,10 @@ impl DataFusionPerTableStats {
val as f64
}

fn str_to_f64(string: &str) -> f64 {
arith_encoder::encode(string)
}

match col_type {
DataType::Boolean => {
generate_stats_for_col!({ col, distr, hll, BooleanArray, to_f64_safe })
Expand Down Expand Up @@ -256,6 +262,9 @@ impl DataFusionPerTableStats {
DataType::Decimal128(_, _) => {
generate_stats_for_col!({ col, distr, hll, Decimal128Array, i128_to_f64 })
}
DataType::Utf8 => {
generate_stats_for_col!({ col, distr, hll, StringArray, str_to_f64 })
}
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -323,6 +332,10 @@ const DEFAULT_EQ_SEL: f64 = 0.005;
const DEFAULT_INEQ_SEL: f64 = 0.3333333333333333;
// Default selectivity estimate for pattern-match operators such as LIKE
const DEFAULT_MATCH_SEL: f64 = 0.005;
// Default selectivity if we have no information
const DEFAULT_UNK_SEL: f64 = 0.005;
// Default n-distinct estimate for derived columns or columns lacking statistics
const DEFAULT_N_DISTINCT: u64 = 200;

const INVALID_SEL: f64 = 0.01;

Expand Down Expand Up @@ -401,37 +414,33 @@ impl<M: MostCommonValues, D: Distribution> CostModel<OptRelNodeTyp> for OptCostM
OptRelNodeTyp::PhysicalEmptyRelation => Self::cost(0.5, 0.01, 0.0),
OptRelNodeTyp::PhysicalLimit => {
let (row_cnt, compute_cost, _) = Self::cost_tuple(&children[0]);
let row_cnt = if let Some(context) = context {
if let Some(optimizer) = optimizer {
let mut fetch_expr =
optimizer.get_all_group_bindings(context.children_group_ids[2], false);
assert!(
fetch_expr.len() == 1,
"fetch expression should be the only expr in the group"
);
let fetch_expr = fetch_expr.pop().unwrap();
assert!(
matches!(
fetch_expr.typ,
OptRelNodeTyp::Constant(ConstantType::UInt64)
),
"fetch type can only be UInt64"
);
let fetch = ConstantExpr::from_rel_node(fetch_expr)
.unwrap()
.value()
.as_u64();
// u64::MAX represents None
if fetch == u64::MAX {
row_cnt
} else {
row_cnt.min(fetch as f64)
}
let row_cnt = if let (Some(context), Some(optimizer)) = (context, optimizer) {
let mut fetch_expr =
optimizer.get_all_group_bindings(context.children_group_ids[2], false);
assert!(
fetch_expr.len() == 1,
"fetch expression should be the only expr in the group"
);
let fetch_expr = fetch_expr.pop().unwrap();
assert!(
matches!(
fetch_expr.typ,
OptRelNodeTyp::Constant(ConstantType::UInt64)
),
"fetch type can only be UInt64"
);
let fetch = ConstantExpr::from_rel_node(fetch_expr)
.unwrap()
.value()
.as_u64();
// u64::MAX represents None
if fetch == u64::MAX {
row_cnt
} else {
panic!("compute_cost() should not be called if optimizer is None")
row_cnt.min(fetch as f64)
}
} else {
panic!("compute_cost() should not be called if context is None")
(row_cnt * DEFAULT_UNK_SEL).max(1.0)
};
Self::cost(row_cnt, compute_cost, 0.0)
}
Expand Down Expand Up @@ -499,10 +508,15 @@ impl<M: MostCommonValues, D: Distribution> CostModel<OptRelNodeTyp> for OptCostM
Self::cost(row_cnt, row_cnt * row_cnt.ln_1p().max(1.0), 0.0)
}
OptRelNodeTyp::PhysicalAgg => {
let (row_cnt, _, _) = Self::cost_tuple(&children[0]);
let child_row_cnt = Self::row_cnt(&children[0]);
let row_cnt = self.get_agg_row_cnt(context, optimizer, child_row_cnt);
let (_, compute_cost_1, _) = Self::cost_tuple(&children[1]);
let (_, compute_cost_2, _) = Self::cost_tuple(&children[2]);
Self::cost(row_cnt, row_cnt * (compute_cost_1 + compute_cost_2), 0.0)
Self::cost(
row_cnt,
child_row_cnt * (compute_cost_1 + compute_cost_2),
0.0,
)
}
OptRelNodeTyp::List => {
let compute_cost = children
Expand Down Expand Up @@ -544,6 +558,58 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
}
}

fn get_agg_row_cnt(
&self,
context: Option<RelNodeContext>,
optimizer: Option<&CascadesOptimizer<OptRelNodeTyp>>,
child_row_cnt: f64,
) -> f64 {
if let (Some(context), Some(optimizer)) = (context, optimizer) {
let group_by_id = context.children_group_ids[2];
let mut group_by_exprs: Vec<Arc<RelNode<OptRelNodeTyp>>> =
optimizer.get_all_group_bindings(group_by_id, false);
assert!(
group_by_exprs.len() == 1,
"ExprList expression should be the only expression in the GROUP BY group"
);
let group_by = group_by_exprs.pop().unwrap();
let group_by = ExprList::from_rel_node(group_by).unwrap();
if group_by.is_empty() {
1.0
} else {
// Multiply the n-distinct of all the group by columns.
// TODO: improve with multi-dimensional n-distinct
let base_table_col_refs = optimizer
.get_property_by_group::<ColumnRefPropertyBuilder>(context.group_id, 1);
base_table_col_refs
.iter()
.take(group_by.len())
.map(|col_ref| match col_ref {
ColumnRef::BaseTableColumnRef { table, col_idx } => {
let table_stats = self.per_table_stats_map.get(table);
let column_stats = table_stats.map(|table_stats| {
table_stats.per_column_stats_vec.get(*col_idx).unwrap()
});

if let Some(Some(column_stats)) = column_stats {
column_stats.ndistinct as f64
} else {
// The column type is not supported or stats are missing.
DEFAULT_N_DISTINCT as f64
}
}
ColumnRef::Derived => DEFAULT_N_DISTINCT as f64,
_ => panic!(
"GROUP BY base table column ref must either be derived or base table"
),
})
.product()
}
} else {
(child_row_cnt * DEFAULT_UNK_SEL).max(1.0)
}
}

/// The expr_tree input must be a "mixed expression tree"
/// An "expression node" refers to a RelNode that returns true for is_expression()
/// A "full expression tree" is where every node in the tree is an expression node
Expand Down
34 changes: 25 additions & 9 deletions optd-datafusion-repr/src/properties/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,26 @@ use std::sync::Arc;
use optd_core::property::PropertyBuilder;

use super::DEFAULT_NAME;
use crate::plan_nodes::{ConstantType, EmptyRelationData, OptRelNodeTyp};
use crate::plan_nodes::{ConstantType, EmptyRelationData, FuncType, OptRelNodeTyp};

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Field {
pub name: String,
pub typ: ConstantType,
pub nullable: bool,
}

impl Field {
/// Generate a field that is only a place holder whose members are never used.
fn placeholder() -> Self {
Self {
name: DEFAULT_NAME.to_string(),
typ: ConstantType::Any,
nullable: true,
}
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Schema {
pub fields: Vec<Field>,
Expand Down Expand Up @@ -87,14 +99,18 @@ impl PropertyBuilder<OptRelNodeTyp> for SchemaPropertyBuilder {
Schema { fields }
}
OptRelNodeTyp::LogOp(_) => Schema {
fields: vec![
Field {
name: DEFAULT_NAME.to_string(),
typ: ConstantType::Any,
nullable: true
};
children.len()
],
fields: vec![Field::placeholder(); children.len()],
},
OptRelNodeTyp::Agg => {
let mut group_by_schema = children[1].clone();
let agg_schema = children[2].clone();
group_by_schema.fields.extend(agg_schema.fields);
group_by_schema
}
OptRelNodeTyp::Func(FuncType::Agg(_)) => Schema {
// TODO: this is just a place holder now.
// The real type should be the column type.
fields: vec![Field::placeholder()],
},
_ => Schema { fields: vec![] },
}
Expand Down
9 changes: 8 additions & 1 deletion optd-gungnir/src/stats/hyperloglog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ pub struct HyperLogLog {
alpha: f64, // The normal HLL multiplier factor.
}

// Serialize common data types for hashing (&str).
impl ByteSerializable for &str {
fn to_bytes(&self) -> Vec<u8> {
self.as_bytes().to_vec()
}
}

// Serialize common data types for hashing (String).
impl ByteSerializable for String {
fn to_bytes(&self) -> Vec<u8> {
self.as_bytes().to_vec()
self.as_str().to_bytes()
}
}

Expand Down
1 change: 0 additions & 1 deletion optd-sqlplannertest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ impl DatafusionDBMS {
task: &str,
flags: &[String],
) -> Result<()> {
println!("task_explain(): called on sql={}", sql);
use std::fmt::Write;

let with_logical = flags.contains(&"with_logical".to_string());
Expand Down
Loading

0 comments on commit ee080d8

Please sign in to comment.