diff --git a/Cargo.lock b/Cargo.lock index 149d4264..f0ea1893 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3886,7 +3886,7 @@ dependencies = [ [[package]] name = "sqlplannertest" version = "0.1.0" -source = "git+https://github.com/risinglightdb/sqlplannertest-rs?branch=main#4a0a5f9de80842de8808169262a1d7e27094d225" +source = "git+https://github.com/risinglightdb/sqlplannertest-rs?branch=main#6122d5be20383c0a6a50327258ea9da8e2d72df5" dependencies = [ "anyhow", "async-trait", diff --git a/optd-core/src/cascades/optimizer.rs b/optd-core/src/cascades/optimizer.rs index d5afc6da..c25d7ba8 100644 --- a/optd-core/src/cascades/optimizer.rs +++ b/optd-core/src/cascades/optimizer.rs @@ -26,7 +26,6 @@ pub type RuleId = usize; #[derive(Default, Clone, Debug)] pub struct OptimizerContext { - pub upper_bound: Option, pub budget_used: bool, pub rules_applied: usize, } diff --git a/optd-core/src/cascades/tasks/optimize_inputs.rs b/optd-core/src/cascades/tasks/optimize_inputs.rs index f84ab647..c8738aba 100644 --- a/optd-core/src/cascades/tasks/optimize_inputs.rs +++ b/optd-core/src/cascades/tasks/optimize_inputs.rs @@ -20,6 +20,23 @@ struct ContinueTask { return_from_optimize_group: bool, } +struct ContinueTaskDisplay<'a>(&'a Option); + +impl std::fmt::Display for ContinueTaskDisplay<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + Some(x) => { + if x.return_from_optimize_group { + write!(f, "return,next_group_idx={}", x.next_group_idx) + } else { + write!(f, "enter,next_group_idx={}", x.next_group_idx) + } + } + None => write!(f, "none"), + } + } +} + pub struct OptimizeInputsTask { expr_id: ExprId, continue_from: Option, @@ -124,7 +141,7 @@ impl Task for OptimizeInputsTask { let children_group_ids = &expr.children; let cost = optimizer.cost(); - trace!(event = "task_begin", task = "optimize_inputs", expr_id = %self.expr_id, continue_from = ?self.continue_from, total_children = %children_group_ids.len()); + trace!(event = "task_begin", task = "optimize_inputs", expr_id = %self.expr_id, continue_from = %ContinueTaskDisplay(&self.continue_from), total_children = %children_group_ids.len()); if let Some(ContinueTask { next_group_idx, @@ -170,11 +187,38 @@ impl Task for OptimizeInputsTask { Some(optimizer), ); let total_cost = cost.sum(&operation_cost, &input_cost); + + if self.pruning { + let group_info = optimizer.get_group_info(group_id); + fn trace_fmt(winner: &Winner) -> String { + match winner { + Winner::Full(winner) => winner.total_weighted_cost.to_string(), + Winner::Impossible => "impossible".to_string(), + Winner::Unknown => "unknown".to_string(), + } + } + trace!( + event = "compute_cost", + task = "optimize_inputs", + expr_id = %self.expr_id, + weighted_cost_so_far = cost.weighted_cost(&total_cost), + winner_weighted_cost = %trace_fmt(&group_info.winner), + current_processing = %next_group_idx, + total_child_groups = %children_group_ids.len()); + if let Some(winner) = group_info.winner.as_full_winner() { + let cost_so_far = cost.weighted_cost(&total_cost); + if winner.total_weighted_cost <= cost_so_far { + trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, result = "pruned"); + return Ok(vec![]); + } + } + } + if next_group_idx < children_group_ids.len() { let child_group_id = children_group_ids[next_group_idx]; let group_idx = next_group_idx; - let group_info = optimizer.get_group_info(child_group_id); - if !group_info.winner.has_full_winner() { + let child_group_info = optimizer.get_group_info(child_group_id); + if !child_group_info.winner.has_full_winner() { if !return_from_optimize_group { trace!(event = "task_yield", task = "optimize_inputs", expr_id = %self.expr_id, group_idx = %group_idx, yield_to = "optimize_group", optimize_group_id = %child_group_id); return Ok(vec![ @@ -189,7 +233,7 @@ impl Task for OptimizeInputsTask { ]); } else { self.update_winner_impossible(optimizer); - trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, "result" = "impossible"); + trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, result = "impossible"); return Ok(vec![]); } } @@ -203,7 +247,7 @@ impl Task for OptimizeInputsTask { )) as Box>]) } else { self.update_winner(input_statistics_ref, operation_cost, total_cost, optimizer); - trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, "result" = "optimized"); + trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, result = "optimized"); Ok(vec![]) } } else { diff --git a/optd-sqlplannertest/src/bin/planner_test_apply.rs b/optd-sqlplannertest/src/bin/planner_test_apply.rs index 829186ca..595e5e76 100644 --- a/optd-sqlplannertest/src/bin/planner_test_apply.rs +++ b/optd-sqlplannertest/src/bin/planner_test_apply.rs @@ -3,14 +3,19 @@ use std::path::Path; use anyhow::Result; use clap::Parser; +use sqlplannertest::PlannerTestApplyOptions; #[derive(Parser)] #[command(version, about, long_about = None)] struct Cli { /// Optional list of directories to apply the test; if empty, apply all tests directories: Vec, + /// Use the advanced cost model #[clap(long)] enable_advanced_cost_model: bool, + /// Execute tests in serial + #[clap(long)] + serial: bool, } #[tokio::main] @@ -20,9 +25,11 @@ async fn main() -> Result<()> { let cli = Cli::parse(); let enable_advanced_cost_model = cli.enable_advanced_cost_model; + let opts = PlannerTestApplyOptions { serial: cli.serial }; + if cli.directories.is_empty() { println!("Running all tests"); - sqlplannertest::planner_test_apply( + sqlplannertest::planner_test_apply_with_options( Path::new(env!("CARGO_MANIFEST_DIR")).join("tests"), move || async move { if enable_advanced_cost_model { @@ -31,12 +38,13 @@ async fn main() -> Result<()> { optd_sqlplannertest::DatafusionDBMS::new().await } }, + opts, ) .await?; } else { for directory in cli.directories { println!("Running tests in {}", directory); - sqlplannertest::planner_test_apply( + sqlplannertest::planner_test_apply_with_options( Path::new(env!("CARGO_MANIFEST_DIR")) .join("tests") .join(directory), @@ -47,6 +55,7 @@ async fn main() -> Result<()> { optd_sqlplannertest::DatafusionDBMS::new().await } }, + opts.clone(), ) .await?; }