From b49a33324baf961a685e074bef26384c4d94ef6b Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 2 May 2025 12:44:19 -0400 Subject: [PATCH 1/8] break out representative methods into new trait --- optd/src/memo/memory.rs | 120 ++++++++++++++++++++-------------------- optd/src/memo/traits.rs | 90 +++++++++++++----------------- 2 files changed, 98 insertions(+), 112 deletions(-) diff --git a/optd/src/memo/memory.rs b/optd/src/memo/memory.rs index 7b0a9313..532d956e 100644 --- a/optd/src/memo/memory.rs +++ b/optd/src/memo/memory.rs @@ -120,6 +120,30 @@ impl GoalState { } } +impl Representative for MemoryMemo { + async fn find_repr_group(&self, group_id: GroupId) -> GroupId { + self.repr_group.find(&group_id) + } + + async fn find_repr_goal(&self, goal_id: GoalId) -> GoalId { + self.repr_goal.find(&goal_id) + } + + async fn find_repr_logical_expr( + &self, + logical_expr_id: LogicalExpressionId, + ) -> LogicalExpressionId { + self.repr_logical_expr.find(&logical_expr_id) + } + + async fn find_repr_physical_expr( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> PhysicalExpressionId { + self.repr_physical_expr.find(&physical_expr_id) + } +} + impl Memoize for MemoryMemo { async fn merge_groups( &mut self, @@ -133,7 +157,7 @@ impl Memoize for MemoryMemo { &self, group_id: GroupId, ) -> MemoResult> { - let group_id = self.find_repr_group(group_id).await?; + let group_id = self.find_repr_group(group_id).await; let group = self .groups .get(&group_id) @@ -147,7 +171,7 @@ impl Memoize for MemoryMemo { group_id: GroupId, props: LogicalProperties, ) -> MemoResult<()> { - let group_id = self.find_repr_group(group_id).await?; + let group_id = self.find_repr_group(group_id).await; let group = self .groups .get_mut(&group_id) @@ -161,7 +185,7 @@ impl Memoize for MemoryMemo { &self, group_id: GroupId, ) -> MemoResult> { - let group_id = self.find_repr_group(group_id).await?; + let group_id = self.find_repr_group(group_id).await; let group = self .groups .get(&group_id) @@ -171,7 +195,7 @@ impl Memoize for MemoryMemo { } async fn get_any_logical_expr(&self, group_id: GroupId) -> MemoResult { - let group_id = self.find_repr_group(group_id).await?; + let group_id = self.find_repr_group(group_id).await; let group = self .groups .get(&group_id) @@ -189,7 +213,7 @@ impl Memoize for MemoryMemo { &self, logical_expr_id: LogicalExpressionId, ) -> MemoResult> { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; let maybe_group_id = self.logical_expr_group_index.get(&logical_expr_id).cloned(); Ok(maybe_group_id) } @@ -208,7 +232,7 @@ impl Memoize for MemoryMemo { &self, goal_id: GoalId, ) -> MemoResult> { - let goal_id = self.find_repr_goal(goal_id).await?; + let goal_id = self.find_repr_goal(goal_id).await; let maybe_best_costed = self .best_optimized_physical_expr_index .get(&goal_id) @@ -217,7 +241,7 @@ impl Memoize for MemoryMemo { } async fn get_all_goal_members(&self, goal_id: GoalId) -> MemoResult> { - let goal_id = self.find_repr_goal(goal_id).await?; + let goal_id = self.find_repr_goal(goal_id).await; let goal_state = self.goals.get(&goal_id).unwrap(); Ok(goal_state.members.iter().cloned().collect()) } @@ -227,7 +251,7 @@ impl Memoize for MemoryMemo { goal_id: GoalId, member: GoalMemberId, ) -> MemoResult> { - let goal_id = self.find_repr_goal(goal_id).await?; + let goal_id = self.find_repr_goal(goal_id).await; let member = self.find_repr_goal_member(member).await?; let goal_state = 
self.goals.get_mut(&goal_id).unwrap(); @@ -278,7 +302,7 @@ impl Memoize for MemoryMemo { &self, physical_expr_id: PhysicalExpressionId, ) -> MemoResult> { - let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; let (_, maybe_cost) = self .physical_exprs .get(&physical_expr_id) @@ -291,7 +315,7 @@ impl Memoize for MemoryMemo { physical_expr_id: PhysicalExpressionId, new_cost: Cost, ) -> MemoResult> { - let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; let (_, cost_mut) = self .physical_exprs .get_mut(&physical_expr_id) @@ -333,7 +357,7 @@ impl Memoize for MemoryMemo { logical_expr_id: LogicalExpressionId, rule: &TransformationRule, ) -> MemoResult { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; let status = self .transform_dependency .get(&logical_expr_id) @@ -348,7 +372,7 @@ impl Memoize for MemoryMemo { logical_expr_id: LogicalExpressionId, rule: &TransformationRule, ) -> MemoResult<()> { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; let status_map = self .transform_dependency .entry(logical_expr_id) @@ -371,8 +395,8 @@ impl Memoize for MemoryMemo { goal_id: GoalId, rule: &ImplementationRule, ) -> MemoResult { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; - let goal_id = self.find_repr_goal(goal_id).await?; + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; + let goal_id = self.find_repr_goal(goal_id).await; let status = self .implement_dependency .get(&logical_expr_id) @@ -388,7 +412,7 @@ impl Memoize for MemoryMemo { goal_id: GoalId, rule: &ImplementationRule, ) -> MemoResult<()> { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; let status_map = self .implement_dependency .entry(logical_expr_id) @@ -409,7 +433,7 @@ impl Memoize for MemoryMemo { &self, physical_expr_id: PhysicalExpressionId, ) -> MemoResult { - let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; let status = self .cost_dependency .get(&physical_expr_id) @@ -419,7 +443,7 @@ impl Memoize for MemoryMemo { } async fn set_cost_clean(&mut self, physical_expr_id: PhysicalExpressionId) -> MemoResult<()> { - let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; let entry = self.cost_dependency.entry(physical_expr_id); @@ -442,8 +466,8 @@ impl Memoize for MemoryMemo { rule: &TransformationRule, group_id: GroupId, ) -> MemoResult<()> { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; - let group_id = self.find_repr_group(group_id).await?; + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; + let group_id = self.find_repr_group(group_id).await; let status_map = self .transform_dependency .entry(logical_expr_id) @@ -471,9 +495,9 @@ impl Memoize for MemoryMemo { rule: &ImplementationRule, group_id: GroupId, ) -> MemoResult<()> { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; - 
let group_id = self.find_repr_group(group_id).await?; - let goal_id = self.find_repr_goal(goal_id).await?; + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; + let group_id = self.find_repr_group(group_id).await; + let goal_id = self.find_repr_goal(goal_id).await; let status_map = self .implement_dependency @@ -500,8 +524,8 @@ impl Memoize for MemoryMemo { physical_expr_id: PhysicalExpressionId, goal_id: GoalId, ) -> MemoResult<()> { - let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; - let goal_id = self.find_repr_goal(goal_id).await?; + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; + let goal_id = self.find_repr_goal(goal_id).await; match self.cost_dependency.entry(physical_expr_id) { Entry::Occupied(occupied) => { @@ -582,7 +606,7 @@ impl Memoize for MemoryMemo { &self, logical_expr_id: LogicalExpressionId, ) -> MemoResult { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await?; + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; let logical_expr = self .logical_exprs .get(&logical_expr_id) @@ -644,39 +668,13 @@ impl Memoize for MemoryMemo { &self, physical_expr_id: PhysicalExpressionId, ) -> MemoResult { - let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; let (physical_expr, _) = self .physical_exprs .get(&physical_expr_id) .ok_or(MemoError::PhysicalExprNotFound(physical_expr_id))?; Ok(physical_expr.clone()) } - - async fn find_repr_group(&self, group_id: GroupId) -> MemoResult { - let repr_group_id = self.repr_group.find(&group_id); - Ok(repr_group_id) - } - - async fn find_repr_goal(&self, goal_id: GoalId) -> MemoResult { - let repr_goal_id = self.repr_goal.find(&goal_id); - Ok(repr_goal_id) - } - - async fn find_repr_logical_expr( - &self, - logical_expr_id: LogicalExpressionId, - ) -> MemoResult { - let repr_expr_id = self.repr_logical_expr.find(&logical_expr_id); - Ok(repr_expr_id) - } - - async fn find_repr_physical_expr( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoResult { - let repr_expr_id = self.repr_physical_expr.find(&physical_expr_id); - Ok(repr_expr_id) - } } impl MemoryMemo { @@ -691,7 +689,7 @@ impl MemoryMemo { for child in repr_logical_expr.children.iter() { match child { Child::Singleton(group_id) => { - let repr_group_id = self.find_repr_group(*group_id).await?; + let repr_group_id = self.find_repr_group(*group_id).await; new_children.push(Child::Singleton(repr_group_id)); } Child::VarLength(group_ids) => { @@ -707,7 +705,7 @@ impl MemoryMemo { let new_group_ids = futures::future::join_all(new_group_ids) .await .into_iter() - .collect::, _>>()?; + .collect(); new_children.push(Child::VarLength(new_group_ids)); } @@ -729,7 +727,7 @@ impl MemoryMemo { match child { Child::Singleton(goal_member_id) => { if let GoalMemberId::GoalId(goal_id) = goal_member_id { - let repr_goal_id = self.find_repr_goal(*goal_id).await?; + let repr_goal_id = self.find_repr_goal(*goal_id).await; new_children.push(Child::Singleton(GoalMemberId::GoalId(repr_goal_id))); } else { new_children.push(Child::Singleton(*goal_member_id)); @@ -740,12 +738,12 @@ impl MemoryMemo { for goal_member_id in goal_member_ids.iter() { match goal_member_id { GoalMemberId::GoalId(goal_id) => { - let repr_goal_id = self.find_repr_goal(*goal_id).await?; + let repr_goal_id = self.find_repr_goal(*goal_id).await; 
new_goal_member_ids.push(GoalMemberId::GoalId(repr_goal_id)); } GoalMemberId::PhysicalExpressionId(physical_expr_id) => { let repr_physical_expr_id = - self.find_repr_physical_expr(*physical_expr_id).await?; + self.find_repr_physical_expr(*physical_expr_id).await; new_goal_member_ids.push(GoalMemberId::PhysicalExpressionId( repr_physical_expr_id, )); @@ -933,8 +931,8 @@ impl MemoryMemo { group_id_2: GroupId, ) -> MemoResult> { // our strategy is to always merge group 2 into group 1. - let group_id_1 = self.find_repr_group(group_id_1).await?; - let group_id_2 = self.find_repr_group(group_id_2).await?; + let group_id_1 = self.find_repr_group(group_id_1).await; + let group_id_2 = self.find_repr_group(group_id_2).await; if group_id_1 == group_id_2 { return Ok(None); @@ -1121,11 +1119,11 @@ impl MemoryMemo { async fn find_repr_goal_member(&self, member: GoalMemberId) -> MemoResult { match member { GoalMemberId::PhysicalExpressionId(physical_expr_id) => { - let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await?; + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; Ok(GoalMemberId::PhysicalExpressionId(physical_expr_id)) } GoalMemberId::GoalId(goal_id) => { - let goal_id = self.find_repr_goal(goal_id).await?; + let goal_id = self.find_repr_goal(goal_id).await; Ok(GoalMemberId::GoalId(goal_id)) } } diff --git a/optd/src/memo/traits.rs b/optd/src/memo/traits.rs index cf4769cf..623aae4b 100644 --- a/optd/src/memo/traits.rs +++ b/optd/src/memo/traits.rs @@ -1,6 +1,44 @@ use super::{MemoResult, MergeProducts, PropagateBestExpression, TaskStatus}; use crate::core::cir::*; +/// A helper trait to help facilitate finding the representative IDs of elements. +#[trait_variant::make(Send)] +pub trait Representative { + /// Finds the representative group of a given group. The representative is usually tracked via a + /// Union-Find data structure. + /// + /// If the input group is already the representative, then the returned [`GroupId`] is equal to + /// the input [`GroupId`]. + async fn find_repr_group(&self, group_id: GroupId) -> GroupId; + + /// Finds the representative goal of a given goal. The representative is usually tracked via a + /// Union-Find data structure. + /// + /// If the input goal is already the representative, then the returned [`GoalId`] is equal to + /// the input [`GoalId`]. + async fn find_repr_goal(&self, goal_id: GoalId) -> GoalId; + + /// Finds the representative logical expression of a given expression. The representative is + /// usually tracked via a Union-Find data structure. + /// + /// If the input expression is already the representative, then the returned + /// [`LogicalExpressionId`] is equal to the input [`LogicalExpressionId`]. + async fn find_repr_logical_expr( + &self, + logical_expr_id: LogicalExpressionId, + ) -> LogicalExpressionId; + + /// Finds the representative physical expression of a given expression. The representative is + /// usually tracked via a Union-Find data structure. + /// + /// If the input expression is already the representative, then the returned + /// [`PhysicalExpressionId`] is equal to the input [`PhysicalExpressionId`]. + async fn find_repr_physical_expr( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> PhysicalExpressionId; +} + /// Core interface for memo-based query optimization. /// /// This trait defines the operations needed to store, retrieve, and manipulate @@ -8,7 +46,7 @@ use crate::core::cir::*; /// query optimization. 
The memo stores logical and physical expressions by their IDs, /// manages expression properties, and tracks optimization status. #[trait_variant::make(Send)] -pub trait Memoize: Send + Sync + 'static { +pub trait Memoize: Representative + Sync { // // Logical expression and group operations. // @@ -349,54 +387,4 @@ pub trait Memoize: Send + Sync + 'static { &self, physical_expr_id: PhysicalExpressionId, ) -> MemoResult; - - // - // Representative ID operations. - // - - /// Finds the representative group ID for a given group ID. - /// - /// # Parameters - /// * `group_id` - The group ID to find the representative for. - /// - /// # Returns - /// The representative group ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_group(&self, group_id: GroupId) -> MemoResult; - - /// Finds the representative goal ID for a given goal ID. - /// - /// # Parameters - /// * `goal_id` - The goal ID to find the representative for. - /// - /// # Returns - /// The representative goal ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_goal(&self, goal_id: GoalId) -> MemoResult; - - /// Finds the representative logical expression ID for a given logical expression ID. - /// - /// # Parameters - /// * `logical_expr_id` - The logical expression ID to find the representative for. - /// - /// # Returns - /// The representative logical expression ID (which may be the same as the input if - /// it's already the representative). - async fn find_repr_logical_expr( - &self, - logical_expr_id: LogicalExpressionId, - ) -> MemoResult; - - /// Finds the representative physical expression ID for a given physical expression ID. - /// - /// # Parameters - /// * `physical_expr_id` - The physical expression ID to find the representative for. - /// - /// # Returns - /// The representative physical expression ID (which may be the same as the input if - /// it's already the representative). 
- async fn find_repr_physical_expr( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoResult; } From fdc23468830c4c3c8b4d0ec3144629f309a9d943 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 2 May 2025 12:47:28 -0400 Subject: [PATCH 2/8] break out materialization methods into trait --- optd/src/memo/memory.rs | 270 ++++++++++++++++++++-------------------- optd/src/memo/traits.rs | 109 ++++++---------- 2 files changed, 175 insertions(+), 204 deletions(-) diff --git a/optd/src/memo/memory.rs b/optd/src/memo/memory.rs index 532d956e..0d3ee559 100644 --- a/optd/src/memo/memory.rs +++ b/optd/src/memo/memory.rs @@ -144,6 +144,142 @@ impl Representative for MemoryMemo { } } +impl Materialize for MemoryMemo { + async fn get_goal_id(&mut self, goal: &Goal) -> MemoResult { + if let Some(goal_id) = self.goal_node_to_id_index.get(goal).cloned() { + return Ok(goal_id); + } + let goal_id = self.next_goal_id(); + self.goal_node_to_id_index.insert(goal.clone(), goal_id); + self.goals.insert(goal_id, GoalState::new(goal.clone())); + + let Goal(group_id, _) = goal; + self.groups.get_mut(group_id).unwrap().goals.insert(goal_id); + Ok(goal_id) + } + + async fn materialize_goal(&self, goal_id: GoalId) -> MemoResult { + let state = self + .goals + .get(&goal_id) + .ok_or(MemoError::GoalNotFound(goal_id))?; + + Ok(state.goal.clone()) + } + + async fn get_logical_expr_id( + &mut self, + logical_expr: &LogicalExpression, + ) -> MemoResult { + if let Some(logical_expr_id) = self + .logical_expr_node_to_id_index + .get(logical_expr) + .cloned() + { + return Ok(logical_expr_id); + } + let logical_expr_id = self.next_logical_expr_id(); + self.logical_expr_node_to_id_index + .insert(logical_expr.clone(), logical_expr_id); + self.logical_exprs + .insert(logical_expr_id, logical_expr.clone()); + + for child in logical_expr.children.iter() { + match child { + Child::Singleton(group_id) => { + self.group_dependent_logical_exprs + .entry(*group_id) + .or_default() + .insert(logical_expr_id); + } + Child::VarLength(group_ids) => { + for group_id in group_ids.iter() { + self.group_dependent_logical_exprs + .entry(*group_id) + .or_default() + .insert(logical_expr_id); + } + } + } + } + Ok(logical_expr_id) + } + + async fn materialize_logical_expr( + &self, + logical_expr_id: LogicalExpressionId, + ) -> MemoResult { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; + let logical_expr = self + .logical_exprs + .get(&logical_expr_id) + .ok_or(MemoError::LogicalExprNotFound(logical_expr_id))?; + Ok(logical_expr.clone()) + } + + async fn get_physical_expr_id( + &mut self, + physical_expr: &PhysicalExpression, + ) -> MemoResult { + if let Some(physical_expr_id) = self + .physical_expr_node_to_id_index + .get(physical_expr) + .cloned() + { + return Ok(physical_expr_id); + } + let physical_expr_id = self.next_physical_expr_id(); + self.physical_expr_node_to_id_index + .insert(physical_expr.clone(), physical_expr_id); + self.physical_exprs + .insert(physical_expr_id, (physical_expr.clone(), None)); + + for child in physical_expr.children.iter() { + match child { + Child::Singleton(goal_member_id) => { + if let GoalMemberId::GoalId(goal_id) = goal_member_id { + self.goal_dependent_physical_exprs + .entry(*goal_id) + .or_default() + .insert(physical_expr_id); + } + } + Child::VarLength(goal_member_ids) => { + for goal_member_id in goal_member_ids.iter() { + match goal_member_id { + GoalMemberId::GoalId(goal_id) => { + self.goal_dependent_physical_exprs + .entry(*goal_id) + .or_default() + 
.insert(physical_expr_id); + } + GoalMemberId::PhysicalExpressionId(child_physical_expr_id) => { + self.physical_expr_dependent_physical_exprs + .entry(*child_physical_expr_id) + .or_default() + .insert(physical_expr_id); + } + } + } + } + } + } + Ok(physical_expr_id) + } + + async fn materialize_physical_expr( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoResult { + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; + let (physical_expr, _) = self + .physical_exprs + .get(&physical_expr_id) + .ok_or(MemoError::PhysicalExprNotFound(physical_expr_id))?; + Ok(physical_expr.clone()) + } +} + impl Memoize for MemoryMemo { async fn merge_groups( &mut self, @@ -541,140 +677,6 @@ impl Memoize for MemoryMemo { Ok(()) } - - async fn get_goal_id(&mut self, goal: &Goal) -> MemoResult { - if let Some(goal_id) = self.goal_node_to_id_index.get(goal).cloned() { - return Ok(goal_id); - } - let goal_id = self.next_goal_id(); - self.goal_node_to_id_index.insert(goal.clone(), goal_id); - self.goals.insert(goal_id, GoalState::new(goal.clone())); - - let Goal(group_id, _) = goal; - self.groups.get_mut(group_id).unwrap().goals.insert(goal_id); - Ok(goal_id) - } - - async fn materialize_goal(&self, goal_id: GoalId) -> MemoResult { - let state = self - .goals - .get(&goal_id) - .ok_or(MemoError::GoalNotFound(goal_id))?; - - Ok(state.goal.clone()) - } - - async fn get_logical_expr_id( - &mut self, - logical_expr: &LogicalExpression, - ) -> MemoResult { - if let Some(logical_expr_id) = self - .logical_expr_node_to_id_index - .get(logical_expr) - .cloned() - { - return Ok(logical_expr_id); - } - let logical_expr_id = self.next_logical_expr_id(); - self.logical_expr_node_to_id_index - .insert(logical_expr.clone(), logical_expr_id); - self.logical_exprs - .insert(logical_expr_id, logical_expr.clone()); - - for child in logical_expr.children.iter() { - match child { - Child::Singleton(group_id) => { - self.group_dependent_logical_exprs - .entry(*group_id) - .or_default() - .insert(logical_expr_id); - } - Child::VarLength(group_ids) => { - for group_id in group_ids.iter() { - self.group_dependent_logical_exprs - .entry(*group_id) - .or_default() - .insert(logical_expr_id); - } - } - } - } - Ok(logical_expr_id) - } - - async fn materialize_logical_expr( - &self, - logical_expr_id: LogicalExpressionId, - ) -> MemoResult { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; - let logical_expr = self - .logical_exprs - .get(&logical_expr_id) - .ok_or(MemoError::LogicalExprNotFound(logical_expr_id))?; - Ok(logical_expr.clone()) - } - - async fn get_physical_expr_id( - &mut self, - physical_expr: &PhysicalExpression, - ) -> MemoResult { - if let Some(physical_expr_id) = self - .physical_expr_node_to_id_index - .get(physical_expr) - .cloned() - { - return Ok(physical_expr_id); - } - let physical_expr_id = self.next_physical_expr_id(); - self.physical_expr_node_to_id_index - .insert(physical_expr.clone(), physical_expr_id); - self.physical_exprs - .insert(physical_expr_id, (physical_expr.clone(), None)); - - for child in physical_expr.children.iter() { - match child { - Child::Singleton(goal_member_id) => { - if let GoalMemberId::GoalId(goal_id) = goal_member_id { - self.goal_dependent_physical_exprs - .entry(*goal_id) - .or_default() - .insert(physical_expr_id); - } - } - Child::VarLength(goal_member_ids) => { - for goal_member_id in goal_member_ids.iter() { - match goal_member_id { - GoalMemberId::GoalId(goal_id) => { - self.goal_dependent_physical_exprs - 
.entry(*goal_id) - .or_default() - .insert(physical_expr_id); - } - GoalMemberId::PhysicalExpressionId(child_physical_expr_id) => { - self.physical_expr_dependent_physical_exprs - .entry(*child_physical_expr_id) - .or_default() - .insert(physical_expr_id); - } - } - } - } - } - } - Ok(physical_expr_id) - } - - async fn materialize_physical_expr( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoResult { - let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; - let (physical_expr, _) = self - .physical_exprs - .get(&physical_expr_id) - .ok_or(MemoError::PhysicalExprNotFound(physical_expr_id))?; - Ok(physical_expr.clone()) - } } impl MemoryMemo { diff --git a/optd/src/memo/traits.rs b/optd/src/memo/traits.rs index 623aae4b..be74f573 100644 --- a/optd/src/memo/traits.rs +++ b/optd/src/memo/traits.rs @@ -39,6 +39,45 @@ pub trait Representative { ) -> PhysicalExpressionId; } +/// A helper trait to help facilitate the materialization and creation of objects in the memo table. +#[trait_variant::make(Send)] +pub trait Materialize { + /// Retrieves the ID of a [`Goal`]. If the [`Goal`] does not already exist in the memo table, + /// creates a new [`Goal`] and returns a fresh [`GoalId`]. + async fn get_goal_id(&mut self, goal: &Goal) -> MemoResult; + + /// Materializes a [`Goal`] from its [`GoalId`]. + async fn materialize_goal(&self, goal_id: GoalId) -> MemoResult; + + /// Retrieves the ID of a [`LogicalExpression`]. If the [`LogicalExpression`] does not already + /// exist in the memo table, creates a new [`LogicalExpression`] and returns a fresh + /// [`LogicalExpressionId`]. + async fn get_logical_expr_id( + &mut self, + logical_expr: &LogicalExpression, + ) -> MemoResult; + + /// Materializes a [`LogicalExpression`] from its [`LogicalExpressionId`]. + async fn materialize_logical_expr( + &self, + logical_expr_id: LogicalExpressionId, + ) -> MemoResult; + + /// Retrieves the ID of a [`PhysicalExpression`]. If the [`PhysicalExpression`] does not already + /// exist in the memo table, creates a new [`PhysicalExpression`] and returns a fresh + /// [`PhysicalExpressionId`]. + async fn get_physical_expr_id( + &mut self, + physical_expr: &PhysicalExpression, + ) -> MemoResult; + + /// Materializes a [`PhysicalExpression`] from its [`PhysicalExpressionId`]. + async fn materialize_physical_expr( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoResult; +} + /// Core interface for memo-based query optimization. /// /// This trait defines the operations needed to store, retrieve, and manipulate @@ -317,74 +356,4 @@ pub trait Memoize: Representative + Sync { physical_expr_id: PhysicalExpressionId, goal_id: GoalId, ) -> MemoResult<()>; - - // - // ID conversion and materialization operations. - // - - /// Gets or creates a goal ID for a given goal. - /// - /// # Parameters - /// * `goal` - The goal to get or create an ID for. - /// - /// # Returns - /// The ID of the goal. - async fn get_goal_id(&mut self, goal: &Goal) -> MemoResult; - - /// Materializes a goal from its ID. - /// - /// # Parameters - /// * `goal_id` - ID of the goal to materialize. - /// - /// # Returns - /// The materialized goal. - async fn materialize_goal(&self, goal_id: GoalId) -> MemoResult; - - /// Gets or creates a logical expression ID for a given logical expression. - /// - /// # Parameters - /// * `logical_expr` - The logical expression to get or create an ID for. - /// - /// # Returns - /// The ID of the logical expression. 
- async fn get_logical_expr_id( - &mut self, - logical_expr: &LogicalExpression, - ) -> MemoResult; - - /// Materializes a logical expression from its ID. - /// - /// # Parameters - /// * `logical_expr_id` - ID of the logical expression to materialize. - /// - /// # Returns - /// The materialized logical expression. - async fn materialize_logical_expr( - &self, - logical_expr_id: LogicalExpressionId, - ) -> MemoResult; - - /// Gets or creates a physical expression ID for a given physical expression. - /// - /// # Parameters - /// * `physical_expr` - The physical expression to get or create an ID for. - /// - /// # Returns - /// The ID of the physical expression. - async fn get_physical_expr_id( - &mut self, - physical_expr: &PhysicalExpression, - ) -> MemoResult; - - /// Materializes a physical expression from its ID. - /// - /// # Parameters - /// * `physical_expr_id` - ID of the physical expression to materialize. - /// - /// # Returns - /// The materialized physical expression. - async fn materialize_physical_expr( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoResult; } From 26bec41c00bb734245e692d611f038ce53a064b7 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 2 May 2025 12:56:39 -0400 Subject: [PATCH 3/8] break out task graph memo operations into new trait --- optd/src/core/optimizer/handlers.rs | 2 +- optd/src/core/optimizer/merge.rs | 2 +- optd/src/core/optimizer/mod.rs | 17 +- optd/src/memo/memory.rs | 382 ++++++++++++++-------------- optd/src/memo/traits.rs | 24 +- 5 files changed, 210 insertions(+), 217 deletions(-) diff --git a/optd/src/core/optimizer/handlers.rs b/optd/src/core/optimizer/handlers.rs index 831dc62d..e0b7b15c 100644 --- a/optd/src/core/optimizer/handlers.rs +++ b/optd/src/core/optimizer/handlers.rs @@ -113,7 +113,7 @@ impl Optimizer { goal_id: GoalId, _task_id: TaskId, ) -> Result<(), Error> { - let goal_id = self.memo.find_repr_goal(goal_id).await?; + let goal_id = self.memo.find_repr_goal(goal_id).await; let member = self.ingest_physical_plan(&plan).await?; diff --git a/optd/src/core/optimizer/merge.rs b/optd/src/core/optimizer/merge.rs index 1a3f1a40..73f54287 100644 --- a/optd/src/core/optimizer/merge.rs +++ b/optd/src/core/optimizer/merge.rs @@ -362,7 +362,7 @@ impl Optimizer { let task = self.tasks.get(task_id).unwrap().as_transform_expression(); let logical_expr_id = task.logical_expr_id; let repr_logical_expr_id = - self.memo.find_repr_logical_expr(logical_expr_id).await?; + self.memo.find_repr_logical_expr(logical_expr_id).await; exprs_to_trans_tasks .entry(repr_logical_expr_id) .or_insert_with(HashMap::new) diff --git a/optd/src/core/optimizer/mod.rs b/optd/src/core/optimizer/mod.rs index ed6c47f4..bcb5033f 100644 --- a/optd/src/core/optimizer/mod.rs +++ b/optd/src/core/optimizer/mod.rs @@ -1,8 +1,5 @@ use crate::catalog::Catalog; -use crate::core::cir::{ - Cost, Goal, GoalId, GroupId, LogicalProperties, PartialLogicalPlan, PartialPhysicalPlan, - PhysicalExpressionId, RuleBook, -}; +use crate::core::cir::*; use crate::core::error::Error; use crate::dsl::analyzer::hir::context::Context; use crate::memo::Memoize; @@ -287,12 +284,6 @@ impl Optimizer { #[cfg(test)] mod tests { - - use std::{sync::Arc, time::Duration}; - - use async_trait::async_trait; - use tokio::task::JoinSet; - use super::*; use crate::{ catalog::CatalogError, @@ -300,8 +291,12 @@ mod tests { Child, GoalMemberId, LogicalExpression, LogicalPlan, Operator, OperatorData, PhysicalExpression, PhysicalPlan, PhysicalProperties, PropertiesData, }, - 
memo::memory::MemoryMemo, + memo::{Materialize, memory::MemoryMemo}, }; + use async_trait::async_trait; + use std::{sync::Arc, time::Duration}; + use tokio::task::JoinSet; + #[derive(Debug)] struct MockCatalog; diff --git a/optd/src/memo/memory.rs b/optd/src/memo/memory.rs index 0d3ee559..14ed9d3c 100644 --- a/optd/src/memo/memory.rs +++ b/optd/src/memo/memory.rs @@ -487,196 +487,6 @@ impl Memoize for MemoryMemo { Ok(None) } } - - async fn get_transformation_status( - &self, - logical_expr_id: LogicalExpressionId, - rule: &TransformationRule, - ) -> MemoResult { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; - let status = self - .transform_dependency - .get(&logical_expr_id) - .and_then(|status_map| status_map.get(rule)) - .map(|dep| dep.status) - .unwrap_or(TaskStatus::Dirty); - Ok(status) - } - - async fn set_transformation_clean( - &mut self, - logical_expr_id: LogicalExpressionId, - rule: &TransformationRule, - ) -> MemoResult<()> { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; - let status_map = self - .transform_dependency - .entry(logical_expr_id) - .or_default(); - match status_map.entry(rule.clone()) { - Entry::Occupied(occupied_entry) => { - let dep = occupied_entry.into_mut(); - dep.status = TaskStatus::Clean; - } - Entry::Vacant(vacant) => { - vacant.insert(RuleDependency::new(TaskStatus::Clean)); - } - } - Ok(()) - } - - async fn get_implementation_status( - &self, - logical_expr_id: LogicalExpressionId, - goal_id: GoalId, - rule: &ImplementationRule, - ) -> MemoResult { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; - let goal_id = self.find_repr_goal(goal_id).await; - let status = self - .implement_dependency - .get(&logical_expr_id) - .and_then(|status_map| status_map.get(&(goal_id, rule.clone()))) - .map(|dep| dep.status) - .unwrap_or(TaskStatus::Dirty); - Ok(status) - } - - async fn set_implementation_clean( - &mut self, - logical_expr_id: LogicalExpressionId, - goal_id: GoalId, - rule: &ImplementationRule, - ) -> MemoResult<()> { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; - let status_map = self - .implement_dependency - .entry(logical_expr_id) - .or_default(); - match status_map.entry((goal_id, rule.clone())) { - Entry::Occupied(occupied_entry) => { - let dep = occupied_entry.into_mut(); - dep.status = TaskStatus::Clean; - } - Entry::Vacant(vacant) => { - vacant.insert(RuleDependency::new(TaskStatus::Clean)); - } - } - Ok(()) - } - - async fn get_cost_status( - &self, - physical_expr_id: PhysicalExpressionId, - ) -> MemoResult { - let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; - let status = self - .cost_dependency - .get(&physical_expr_id) - .map(|dep| dep.status) - .unwrap_or(TaskStatus::Dirty); - Ok(status) - } - - async fn set_cost_clean(&mut self, physical_expr_id: PhysicalExpressionId) -> MemoResult<()> { - let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; - - let entry = self.cost_dependency.entry(physical_expr_id); - - match entry { - Entry::Occupied(occupied) => { - let dep = occupied.into_mut(); - dep.status = TaskStatus::Clean; - } - Entry::Vacant(vacant) => { - vacant.insert(CostDependency::new(TaskStatus::Clean)); - } - } - - Ok(()) - } - - async fn add_transformation_dependency( - &mut self, - logical_expr_id: LogicalExpressionId, - rule: &TransformationRule, - group_id: GroupId, - ) -> MemoResult<()> { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; - 
let group_id = self.find_repr_group(group_id).await; - let status_map = self - .transform_dependency - .entry(logical_expr_id) - .or_default(); - - match status_map.entry(rule.clone()) { - Entry::Occupied(occupied_entry) => { - let dep = occupied_entry.into_mut(); - dep.group_ids.insert(group_id); - } - Entry::Vacant(vacant) => { - let mut dep = RuleDependency::new(TaskStatus::Dirty); - dep.group_ids.insert(group_id); - vacant.insert(dep); - } - } - - Ok(()) - } - - async fn add_implementation_dependency( - &mut self, - logical_expr_id: LogicalExpressionId, - goal_id: GoalId, - rule: &ImplementationRule, - group_id: GroupId, - ) -> MemoResult<()> { - let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; - let group_id = self.find_repr_group(group_id).await; - let goal_id = self.find_repr_goal(goal_id).await; - - let status_map = self - .implement_dependency - .entry(logical_expr_id) - .or_default(); - - match status_map.entry((goal_id, rule.clone())) { - Entry::Occupied(occupied) => { - let dep = occupied.into_mut(); - dep.group_ids.insert(group_id); - } - Entry::Vacant(vacant) => { - let mut dep = RuleDependency::new(TaskStatus::Dirty); - dep.group_ids.insert(group_id); - vacant.insert(dep); - } - } - - Ok(()) - } - - async fn add_cost_dependency( - &mut self, - physical_expr_id: PhysicalExpressionId, - goal_id: GoalId, - ) -> MemoResult<()> { - let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; - let goal_id = self.find_repr_goal(goal_id).await; - - match self.cost_dependency.entry(physical_expr_id) { - Entry::Occupied(occupied) => { - let dep = occupied.into_mut(); - dep.goal_ids.insert(goal_id); - } - Entry::Vacant(vacant) => { - let mut dep = CostDependency::new(TaskStatus::Dirty); - dep.goal_ids.insert(goal_id); - vacant.insert(dep); - } - } - - Ok(()) - } } impl MemoryMemo { @@ -1131,3 +941,195 @@ impl MemoryMemo { } } } + +impl TaskGraphState for MemoryMemo { + async fn get_transformation_status( + &self, + logical_expr_id: LogicalExpressionId, + rule: &TransformationRule, + ) -> MemoResult { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; + let status = self + .transform_dependency + .get(&logical_expr_id) + .and_then(|status_map| status_map.get(rule)) + .map(|dep| dep.status) + .unwrap_or(TaskStatus::Dirty); + Ok(status) + } + + async fn set_transformation_clean( + &mut self, + logical_expr_id: LogicalExpressionId, + rule: &TransformationRule, + ) -> MemoResult<()> { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; + let status_map = self + .transform_dependency + .entry(logical_expr_id) + .or_default(); + match status_map.entry(rule.clone()) { + Entry::Occupied(occupied_entry) => { + let dep = occupied_entry.into_mut(); + dep.status = TaskStatus::Clean; + } + Entry::Vacant(vacant) => { + vacant.insert(RuleDependency::new(TaskStatus::Clean)); + } + } + Ok(()) + } + + async fn get_implementation_status( + &self, + logical_expr_id: LogicalExpressionId, + goal_id: GoalId, + rule: &ImplementationRule, + ) -> MemoResult { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; + let goal_id = self.find_repr_goal(goal_id).await; + let status = self + .implement_dependency + .get(&logical_expr_id) + .and_then(|status_map| status_map.get(&(goal_id, rule.clone()))) + .map(|dep| dep.status) + .unwrap_or(TaskStatus::Dirty); + Ok(status) + } + + async fn set_implementation_clean( + &mut self, + logical_expr_id: LogicalExpressionId, + goal_id: GoalId, + rule: 
&ImplementationRule, + ) -> MemoResult<()> { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; + let status_map = self + .implement_dependency + .entry(logical_expr_id) + .or_default(); + match status_map.entry((goal_id, rule.clone())) { + Entry::Occupied(occupied_entry) => { + let dep = occupied_entry.into_mut(); + dep.status = TaskStatus::Clean; + } + Entry::Vacant(vacant) => { + vacant.insert(RuleDependency::new(TaskStatus::Clean)); + } + } + Ok(()) + } + + async fn get_cost_status( + &self, + physical_expr_id: PhysicalExpressionId, + ) -> MemoResult { + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; + let status = self + .cost_dependency + .get(&physical_expr_id) + .map(|dep| dep.status) + .unwrap_or(TaskStatus::Dirty); + Ok(status) + } + + async fn set_cost_clean(&mut self, physical_expr_id: PhysicalExpressionId) -> MemoResult<()> { + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; + + let entry = self.cost_dependency.entry(physical_expr_id); + + match entry { + Entry::Occupied(occupied) => { + let dep = occupied.into_mut(); + dep.status = TaskStatus::Clean; + } + Entry::Vacant(vacant) => { + vacant.insert(CostDependency::new(TaskStatus::Clean)); + } + } + + Ok(()) + } + + async fn add_transformation_dependency( + &mut self, + logical_expr_id: LogicalExpressionId, + rule: &TransformationRule, + group_id: GroupId, + ) -> MemoResult<()> { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; + let group_id = self.find_repr_group(group_id).await; + let status_map = self + .transform_dependency + .entry(logical_expr_id) + .or_default(); + + match status_map.entry(rule.clone()) { + Entry::Occupied(occupied_entry) => { + let dep = occupied_entry.into_mut(); + dep.group_ids.insert(group_id); + } + Entry::Vacant(vacant) => { + let mut dep = RuleDependency::new(TaskStatus::Dirty); + dep.group_ids.insert(group_id); + vacant.insert(dep); + } + } + + Ok(()) + } + + async fn add_implementation_dependency( + &mut self, + logical_expr_id: LogicalExpressionId, + goal_id: GoalId, + rule: &ImplementationRule, + group_id: GroupId, + ) -> MemoResult<()> { + let logical_expr_id = self.find_repr_logical_expr(logical_expr_id).await; + let group_id = self.find_repr_group(group_id).await; + let goal_id = self.find_repr_goal(goal_id).await; + + let status_map = self + .implement_dependency + .entry(logical_expr_id) + .or_default(); + + match status_map.entry((goal_id, rule.clone())) { + Entry::Occupied(occupied) => { + let dep = occupied.into_mut(); + dep.group_ids.insert(group_id); + } + Entry::Vacant(vacant) => { + let mut dep = RuleDependency::new(TaskStatus::Dirty); + dep.group_ids.insert(group_id); + vacant.insert(dep); + } + } + + Ok(()) + } + + async fn add_cost_dependency( + &mut self, + physical_expr_id: PhysicalExpressionId, + goal_id: GoalId, + ) -> MemoResult<()> { + let physical_expr_id = self.find_repr_physical_expr(physical_expr_id).await; + let goal_id = self.find_repr_goal(goal_id).await; + + match self.cost_dependency.entry(physical_expr_id) { + Entry::Occupied(occupied) => { + let dep = occupied.into_mut(); + dep.goal_ids.insert(goal_id); + } + Entry::Vacant(vacant) => { + let mut dep = CostDependency::new(TaskStatus::Dirty); + dep.goal_ids.insert(goal_id); + vacant.insert(dep); + } + } + + Ok(()) + } +} diff --git a/optd/src/memo/traits.rs b/optd/src/memo/traits.rs index be74f573..cf7b0c6c 100644 --- a/optd/src/memo/traits.rs +++ b/optd/src/memo/traits.rs @@ -78,18 +78,12 @@ pub trait 
Materialize { ) -> MemoResult; } -/// Core interface for memo-based query optimization. +/// The interface for an optimizer memoization (memo) table. /// -/// This trait defines the operations needed to store, retrieve, and manipulate -/// the memo data structure that powers the dynamic programming approach to -/// query optimization. The memo stores logical and physical expressions by their IDs, -/// manages expression properties, and tracks optimization status. +/// This trait mainly describes operations related to groups, goals, logical and physical +/// expressions, and finding representative nodes of the union-find substructures. #[trait_variant::make(Send)] -pub trait Memoize: Representative + Sync { - // - // Logical expression and group operations. - // - +pub trait Memoize: Representative + Materialize + TaskGraphState + Sync + 'static { /// Retrieves logical properties for a group ID. /// /// # Parameters @@ -228,11 +222,13 @@ pub trait Memoize: Representative + Sync { &self, physical_expr_id: PhysicalExpressionId, ) -> MemoResult>; +} - // - // Rule and costing status operations. - // - +/// Rule and costing status operations. +/// +/// TODO(connor): Clean up docs. +#[trait_variant::make(Send)] +pub trait TaskGraphState { /// Checks the status of applying a transformation rule on a logical expression ID. /// /// # Parameters From 68ffd7f0fc3ccc06a375843728f56f580f1ab608 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 2 May 2025 12:58:56 -0400 Subject: [PATCH 4/8] rename Memoize to Memo --- optd/src/core/optimizer/client.rs | 6 +++--- optd/src/core/optimizer/egest.rs | 4 ++-- optd/src/core/optimizer/forward.rs | 4 ++-- optd/src/core/optimizer/handlers.rs | 4 ++-- optd/src/core/optimizer/ingest.rs | 4 ++-- optd/src/core/optimizer/jobs.rs | 4 ++-- optd/src/core/optimizer/merge.rs | 7 +++---- optd/src/core/optimizer/mod.rs | 6 +++--- optd/src/core/optimizer/tasks/continue_with_costed.rs | 4 ++-- optd/src/core/optimizer/tasks/continue_with_logical.rs | 4 ++-- optd/src/core/optimizer/tasks/cost_expr.rs | 4 ++-- optd/src/core/optimizer/tasks/explore_group.rs | 4 ++-- optd/src/core/optimizer/tasks/fork_costed.rs | 4 ++-- optd/src/core/optimizer/tasks/fork_logical.rs | 4 ++-- optd/src/core/optimizer/tasks/implement_expr.rs | 4 ++-- optd/src/core/optimizer/tasks/mod.rs | 4 ++-- optd/src/core/optimizer/tasks/optimize_goal.rs | 4 ++-- optd/src/core/optimizer/tasks/optimize_plan.rs | 4 ++-- optd/src/core/optimizer/tasks/transform_expr.rs | 4 ++-- optd/src/memo/memory.rs | 2 +- optd/src/memo/traits.rs | 2 +- 21 files changed, 43 insertions(+), 44 deletions(-) diff --git a/optd/src/core/optimizer/client.rs b/optd/src/core/optimizer/client.rs index 2ca922ad..0b462c76 100644 --- a/optd/src/core/optimizer/client.rs +++ b/optd/src/core/optimizer/client.rs @@ -6,7 +6,7 @@ use futures::{ use crate::{ core::cir::{LogicalPlan, PhysicalPlan}, core::error::Error, - memo::Memoize, + memo::Memo, }; /// Unique identifier for a query instance. 
@@ -57,12 +57,12 @@ impl Drop for QueryInstance { } } -pub struct Client { +pub struct Client { tx: mpsc::Sender, handle: tokio::task::JoinHandle, } -impl Client { +impl Client { pub fn new(tx: mpsc::Sender, handle: tokio::task::JoinHandle) -> Self { Self { tx, handle } } diff --git a/optd/src/core/optimizer/egest.rs b/optd/src/core/optimizer/egest.rs index 366ee87e..88fde2fb 100644 --- a/optd/src/core/optimizer/egest.rs +++ b/optd/src/core/optimizer/egest.rs @@ -8,12 +8,12 @@ use crate::{ Child, GoalMemberId, Operator, PartialPhysicalPlan, PhysicalExpressionId, PhysicalPlan, }, core::error::Error, - memo::Memoize, + memo::Memo, }; use super::Optimizer; -impl Optimizer { +impl Optimizer { /// Recursively transforms a physical expression ID in the memo into a complete physical plan. /// /// This function retrieves the physical expression from the memo and recursively diff --git a/optd/src/core/optimizer/forward.rs b/optd/src/core/optimizer/forward.rs index 7d9835bd..0571db1c 100644 --- a/optd/src/core/optimizer/forward.rs +++ b/optd/src/core/optimizer/forward.rs @@ -1,11 +1,11 @@ use crate::{ core::error::Error, - memo::{Memoize, PropagateBestExpression}, + memo::{Memo, PropagateBestExpression}, }; use super::Optimizer; -impl Optimizer { +impl Optimizer { pub(super) async fn handle_forward_result( &mut self, result: PropagateBestExpression, diff --git a/optd/src/core/optimizer/handlers.rs b/optd/src/core/optimizer/handlers.rs index e0b7b15c..e0f3482a 100644 --- a/optd/src/core/optimizer/handlers.rs +++ b/optd/src/core/optimizer/handlers.rs @@ -7,7 +7,7 @@ use crate::{ PartialPhysicalPlan, PhysicalExpressionId, }, core::error::Error, - memo::Memoize, + memo::Memo, }; use crate::dsl::{ @@ -16,7 +16,7 @@ use crate::dsl::{ }; use futures::{SinkExt, channel::mpsc::Sender}; -impl Optimizer { +impl Optimizer { /// This method initiates the optimization process for a logical plan by launching /// an optimization task. It may need dependencies. /// diff --git a/optd/src/core/optimizer/ingest.rs b/optd/src/core/optimizer/ingest.rs index b02e1f8b..3b0a9b91 100644 --- a/optd/src/core/optimizer/ingest.rs +++ b/optd/src/core/optimizer/ingest.rs @@ -5,12 +5,12 @@ use crate::{ PartialPhysicalPlan, PhysicalExpression, }, core::error::Error, - memo::Memoize, + memo::Memo, }; use Child::*; use std::sync::Arc; -impl Optimizer { +impl Optimizer { /// Ingests a logical plan into the memo. /// /// Returns the group id of the root logical expression. 
diff --git a/optd/src/core/optimizer/jobs.rs b/optd/src/core/optimizer/jobs.rs index c7cde320..5ad797f0 100644 --- a/optd/src/core/optimizer/jobs.rs +++ b/optd/src/core/optimizer/jobs.rs @@ -10,7 +10,7 @@ use crate::core::cir::{ }; use crate::core::error::Error; use crate::dsl::engine::{Engine, EngineResponse}; -use crate::memo::Memoize; +use crate::memo::Memo; use EngineMessageKind::*; use futures::SinkExt; use futures::channel::mpsc; @@ -47,7 +47,7 @@ pub enum Job { Derive(GroupId), } -impl Optimizer { +impl Optimizer { // // Job Scheduling and Management // diff --git a/optd/src/core/optimizer/merge.rs b/optd/src/core/optimizer/merge.rs index 73f54287..8280d6c7 100644 --- a/optd/src/core/optimizer/merge.rs +++ b/optd/src/core/optimizer/merge.rs @@ -7,10 +7,10 @@ use crate::{ core::cir::{GoalId, GroupId, ImplementationRule, LogicalExpressionId, TransformationRule}, core::error::Error, core::optimizer::tasks::TaskId, - memo::{Memoize, MergeProducts}, + memo::{Memo, MergeProducts}, }; -impl Optimizer { +impl Optimizer { /// Recursively deletes tasks that are no longer needed. /// Confirms before deletion that the task is not subscribed to by any other task by checking if the parent tasks (the outs) exist in the task index. /// If the parent task is not found in the task index, then the given task is safely deleted. Else, we do not delete the task. @@ -361,8 +361,7 @@ impl Optimizer { { let task = self.tasks.get(task_id).unwrap().as_transform_expression(); let logical_expr_id = task.logical_expr_id; - let repr_logical_expr_id = - self.memo.find_repr_logical_expr(logical_expr_id).await; + let repr_logical_expr_id = self.memo.find_repr_logical_expr(logical_expr_id).await; exprs_to_trans_tasks .entry(repr_logical_expr_id) .or_insert_with(HashMap::new) diff --git a/optd/src/core/optimizer/mod.rs b/optd/src/core/optimizer/mod.rs index bcb5033f..5272db0c 100644 --- a/optd/src/core/optimizer/mod.rs +++ b/optd/src/core/optimizer/mod.rs @@ -2,7 +2,7 @@ use crate::catalog::Catalog; use crate::core::cir::*; use crate::core::error::Error; use crate::dsl::analyzer::hir::context::Context; -use crate::memo::Memoize; +use crate::memo::Memo; use EngineMessageKind::*; pub use client::{Client, QueryInstance}; use client::{ClientMessage, QueryInstanceId}; @@ -87,7 +87,7 @@ pub enum EngineMessageKind { /// /// Provides the interface to submit logical plans for optimization and receive /// optimized physical plans in return. -pub struct Optimizer { +pub struct Optimizer { // Core components. memo: M, rule_book: RuleBook, @@ -131,7 +131,7 @@ pub struct Optimizer { cost_expression_task_index: HashMap, } -impl Optimizer { +impl Optimizer { /// Create a new optimizer instance with the given memo and HIR context. /// /// Use `launch` to create and start the optimizer. 
diff --git a/optd/src/core/optimizer/tasks/continue_with_costed.rs b/optd/src/core/optimizer/tasks/continue_with_costed.rs index 8c974724..ae4b2511 100644 --- a/optd/src/core/optimizer/tasks/continue_with_costed.rs +++ b/optd/src/core/optimizer/tasks/continue_with_costed.rs @@ -3,7 +3,7 @@ use crate::{ core::cir::{Cost, PhysicalExpressionId}, core::error::Error, core::optimizer::{JobId, Optimizer}, - memo::Memoize, + memo::Memo, }; use super::{Task, TaskId}; @@ -31,7 +31,7 @@ impl ContinueWithCostedTask { } } -impl Optimizer { +impl Optimizer { pub async fn create_continue_with_costed_task( &mut self, physical_expr_id: PhysicalExpressionId, diff --git a/optd/src/core/optimizer/tasks/continue_with_logical.rs b/optd/src/core/optimizer/tasks/continue_with_logical.rs index 4d1c5d50..e1af4216 100644 --- a/optd/src/core/optimizer/tasks/continue_with_logical.rs +++ b/optd/src/core/optimizer/tasks/continue_with_logical.rs @@ -3,7 +3,7 @@ use crate::{ core::cir::LogicalExpressionId, core::error::Error, core::optimizer::{JobId, Optimizer}, - memo::Memoize, + memo::Memo, }; use super::{Task, TaskId}; @@ -29,7 +29,7 @@ impl ContinueWithLogicalTask { } } -impl Optimizer { +impl Optimizer { /// Creates a `ContinueWithLogical` task. pub(crate) async fn create_continue_with_logical_task( &mut self, diff --git a/optd/src/core/optimizer/tasks/cost_expr.rs b/optd/src/core/optimizer/tasks/cost_expr.rs index 83b95293..0d032441 100644 --- a/optd/src/core/optimizer/tasks/cost_expr.rs +++ b/optd/src/core/optimizer/tasks/cost_expr.rs @@ -7,7 +7,7 @@ use crate::{ core::cir::{Cost, PhysicalExpressionId}, core::error::Error, core::optimizer::{EngineMessageKind, JobId, Optimizer, Task}, - memo::{Memoize, TaskStatus}, + memo::{Memo, TaskStatus}, }; use super::TaskId; @@ -57,7 +57,7 @@ impl CostExpressionTask { } } -impl Optimizer { +impl Optimizer { /// Ensures a cost expression task exists and sets up a parent-child relationship. /// /// This is used when a task needs to cost a physical expression as part of its work. diff --git a/optd/src/core/optimizer/tasks/explore_group.rs b/optd/src/core/optimizer/tasks/explore_group.rs index 9f42765b..d08449e5 100644 --- a/optd/src/core/optimizer/tasks/explore_group.rs +++ b/optd/src/core/optimizer/tasks/explore_group.rs @@ -2,7 +2,7 @@ use crate::{ core::cir::{GroupId, LogicalExpressionId}, core::error::Error, core::optimizer::Optimizer, - memo::Memoize, + memo::Memo, }; use super::{SourceTaskId, Task, TaskId}; @@ -45,7 +45,7 @@ impl ExploreGroupTask { } } -impl Optimizer { +impl Optimizer { // Creates the `ExploreGroup` task if it does not exist and get all current logical expressions. 
pub async fn ensure_explore_group_task( &mut self, diff --git a/optd/src/core/optimizer/tasks/fork_costed.rs b/optd/src/core/optimizer/tasks/fork_costed.rs index b94303e1..cbcf3162 100644 --- a/optd/src/core/optimizer/tasks/fork_costed.rs +++ b/optd/src/core/optimizer/tasks/fork_costed.rs @@ -7,7 +7,7 @@ use crate::{ core::cir::{Cost, GoalId}, core::error::Error, core::optimizer::{EngineMessageKind, Optimizer}, - memo::Memoize, + memo::Memo, }; use super::{SourceTaskId, Task, TaskId}; @@ -49,7 +49,7 @@ impl ForkCostedTask { } } -impl Optimizer { +impl Optimizer { pub(crate) async fn create_fork_costed_task( &mut self, goal_id: GoalId, diff --git a/optd/src/core/optimizer/tasks/fork_logical.rs b/optd/src/core/optimizer/tasks/fork_logical.rs index 23fb1ef5..4bb6ff49 100644 --- a/optd/src/core/optimizer/tasks/fork_logical.rs +++ b/optd/src/core/optimizer/tasks/fork_logical.rs @@ -6,7 +6,7 @@ use crate::dsl::{ use crate::{ core::cir::GroupId, core::optimizer::{EngineMessageKind, Optimizer}, - memo::Memoize, + memo::Memo, }; use super::{SourceTaskId, Task, TaskId}; @@ -41,7 +41,7 @@ impl ForkLogicalTask { } } -impl Optimizer { +impl Optimizer { /// Creates a task to fork the logical plan for further exploration. /// /// This task generates alternative logical expressions that are diff --git a/optd/src/core/optimizer/tasks/implement_expr.rs b/optd/src/core/optimizer/tasks/implement_expr.rs index 9d865140..e7d84918 100644 --- a/optd/src/core/optimizer/tasks/implement_expr.rs +++ b/optd/src/core/optimizer/tasks/implement_expr.rs @@ -10,7 +10,7 @@ use crate::{ core::cir::{Goal, GoalId, ImplementationRule, LogicalExpressionId}, core::error::Error, core::optimizer::{EngineMessageKind, JobId, Optimizer, Task}, - memo::{Memoize, TaskStatus}, + memo::{Memo, TaskStatus}, }; use super::TaskId; @@ -59,7 +59,7 @@ impl ImplementExpressionTask { } } -impl Optimizer { +impl Optimizer { pub(crate) async fn create_implement_expression_task( &mut self, rule: ImplementationRule, diff --git a/optd/src/core/optimizer/tasks/mod.rs b/optd/src/core/optimizer/tasks/mod.rs index 77c29a97..31c633ce 100644 --- a/optd/src/core/optimizer/tasks/mod.rs +++ b/optd/src/core/optimizer/tasks/mod.rs @@ -20,7 +20,7 @@ use optimize_goal::*; use optimize_plan::*; use transform_expr::*; -use crate::memo::Memoize; +use crate::memo::Memo; use super::Optimizer; @@ -85,7 +85,7 @@ pub enum Task { ContinueWithCosted(ContinueWithCostedTask), } -impl Optimizer { +impl Optimizer { pub fn get_task(&self, task_id: TaskId) -> &Task { self.tasks.get(&task_id).unwrap() } diff --git a/optd/src/core/optimizer/tasks/optimize_goal.rs b/optd/src/core/optimizer/tasks/optimize_goal.rs index 530985f3..d5fdedb8 100644 --- a/optd/src/core/optimizer/tasks/optimize_goal.rs +++ b/optd/src/core/optimizer/tasks/optimize_goal.rs @@ -4,7 +4,7 @@ use crate::{ core::cir::{Cost, Goal, GoalId, GoalMemberId, PhysicalExpressionId}, core::error::Error, core::optimizer::{Optimizer, Task}, - memo::Memoize, + memo::Memo, }; use super::{SourceTaskId, TaskId}; @@ -67,7 +67,7 @@ impl OptimizeGoalTask { } } -impl Optimizer { +impl Optimizer { pub async fn ensure_optimize_goal_task( &mut self, goal_id: GoalId, diff --git a/optd/src/core/optimizer/tasks/optimize_plan.rs b/optd/src/core/optimizer/tasks/optimize_plan.rs index 65544a5a..7bf1962f 100644 --- a/optd/src/core/optimizer/tasks/optimize_plan.rs +++ b/optd/src/core/optimizer/tasks/optimize_plan.rs @@ -4,7 +4,7 @@ use crate::{ core::cir::{Goal, LogicalPlan, PhysicalExpressionId, PhysicalPlan, PhysicalProperties}, 
core::error::Error, core::optimizer::{Optimizer, tasks::SourceTaskId}, - memo::Memoize, + memo::Memo, }; use super::{Task, TaskId}; @@ -36,7 +36,7 @@ impl OptimizePlanTask { } } -impl Optimizer { +impl Optimizer { pub async fn emit_best_physical_plan( &mut self, mut physical_plan_tx: mpsc::Sender, diff --git a/optd/src/core/optimizer/tasks/transform_expr.rs b/optd/src/core/optimizer/tasks/transform_expr.rs index 49acba94..fa41aae6 100644 --- a/optd/src/core/optimizer/tasks/transform_expr.rs +++ b/optd/src/core/optimizer/tasks/transform_expr.rs @@ -7,7 +7,7 @@ use crate::{ core::cir::{LogicalExpressionId, PartialLogicalPlan, TransformationRule}, core::error::Error, core::optimizer::{EngineMessageKind, JobId, Optimizer, Task}, - memo::{Memoize, TaskStatus}, + memo::{Memo, TaskStatus}, }; use super::TaskId; @@ -52,7 +52,7 @@ impl TransformExpressionTask { } } -impl Optimizer { +impl Optimizer { /// Creates a task to start applying a transformation rule to a logical expression. /// /// This task generates alternative logical expressions that are diff --git a/optd/src/memo/memory.rs b/optd/src/memo/memory.rs index 14ed9d3c..0e3cf625 100644 --- a/optd/src/memo/memory.rs +++ b/optd/src/memo/memory.rs @@ -280,7 +280,7 @@ impl Materialize for MemoryMemo { } } -impl Memoize for MemoryMemo { +impl Memo for MemoryMemo { async fn merge_groups( &mut self, group_id_1: GroupId, diff --git a/optd/src/memo/traits.rs b/optd/src/memo/traits.rs index cf7b0c6c..3bcd23ca 100644 --- a/optd/src/memo/traits.rs +++ b/optd/src/memo/traits.rs @@ -83,7 +83,7 @@ pub trait Materialize { /// This trait mainly describes operations related to groups, goals, logical and physical /// expressions, and finding representative nodes of the union-find substructures. #[trait_variant::make(Send)] -pub trait Memoize: Representative + Materialize + TaskGraphState + Sync + 'static { +pub trait Memo: Representative + Materialize + TaskGraphState + Sync + 'static { /// Retrieves logical properties for a group ID. /// /// # Parameters From 516b22b3e8d2d7a4857accee938e451160b019ee Mon Sep 17 00:00:00 2001 From: Connor Tsui <87130162+connortsui20@users.noreply.github.com> Date: Fri, 2 May 2025 13:03:47 -0400 Subject: [PATCH 5/8] `optd-cli` crate (#95) ## Problem If we want users of the DSL to add their own UDFs, it makes more sense to have them recompile just the CLI with their functions rather than the entirety of `optd`. ## Summary of changes Extracts the `cli` module into its own crate. Ideally we figure out how to let users of the CLI add their own UDFs in an ergonomic way without having to fork our repository. This should be easy (could either use a proc macro or expose more of the compilation functions as building blocks in a library), and can be left to a later point in time. 
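
To make that concrete, here is a minimal, self-contained sketch of the intended registration pattern: a downstream binary collects its UDFs in a name-to-function map and hands them to a compile entry point, so only this small crate needs to be rebuilt. The `Udf`, `Config`, and `compile` items below are local stand-ins for illustration only, not the actual `optd::dsl` API.

```rust
// Self-contained sketch of the "register UDFs, then compile" pattern.
// The types here are stand-ins, not the real optd::dsl API.
use std::collections::HashMap;

/// A user-defined function: takes argument values, returns a value.
type Udf = fn(&[i64]) -> i64;

/// Stand-in for the compiler configuration (input path, flags, ...).
struct Config {
    input_path: String,
}

/// Stand-in for a compile entry point that links UDFs in by name.
fn compile(config: &Config, udfs: &HashMap<String, Udf>) -> Result<(), String> {
    println!("compiling {} with {} UDF(s)", config.input_path, udfs.len());
    Ok(())
}

fn hello(_args: &[i64]) -> i64 {
    42
}

fn main() {
    // A forked CLI registers its own functions here and recompiles only
    // this binary, rather than all of `optd`.
    let mut udfs: HashMap<String, Udf> = HashMap::new();
    udfs.insert("hello".to_string(), hello);

    let config = Config {
        input_path: "examples/tutorial.opt".to_string(),
    };
    compile(&config, &udfs).expect("compilation failed");
}
```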
--- Cargo.lock | 9 +++++++++ Cargo.toml | 2 +- optd-cli/Cargo.toml | 11 +++++++++++ {optd/src/dsl => optd-cli}/examples/generics.opt | 2 +- {optd/src/dsl => optd-cli}/examples/higher_order.opt | 2 +- {optd/src/dsl => optd-cli}/examples/lists.opt | 2 +- {optd/src/dsl => optd-cli}/examples/tutorial.opt | 2 +- {optd/src/dsl/cli => optd-cli/src}/main.rs | 3 +-- optd/Cargo.toml | 4 ---- 9 files changed, 26 insertions(+), 11 deletions(-) create mode 100644 optd-cli/Cargo.toml rename {optd/src/dsl => optd-cli}/examples/generics.opt (99%) rename {optd/src/dsl => optd-cli}/examples/higher_order.opt (99%) rename {optd/src/dsl => optd-cli}/examples/lists.opt (92%) rename {optd/src/dsl => optd-cli}/examples/tutorial.opt (99%) rename {optd/src/dsl/cli => optd-cli/src}/main.rs (99%) diff --git a/Cargo.lock b/Cargo.lock index ca4487ec..25bb15fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2013,6 +2013,15 @@ dependencies = [ "trait-variant", ] +[[package]] +name = "optd-cli" +version = "0.1.0" +dependencies = [ + "clap", + "colored", + "optd", +] + [[package]] name = "ordered-float" version = "2.10.1" diff --git a/Cargo.toml b/Cargo.toml index 05442ef7..484cf790 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["optd"] +members = ["optd", "optd-cli"] resolver = "2" [workspace.package] diff --git a/optd-cli/Cargo.toml b/optd-cli/Cargo.toml new file mode 100644 index 00000000..030f0d69 --- /dev/null +++ b/optd-cli/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "optd-cli" +version.workspace = true +edition.workspace = true +repository.workspace = true + +[dependencies] +optd = { path = "../optd" } + +clap = { version = "4.5.37", features = ["derive"] } +colored = "3.0.0" diff --git a/optd/src/dsl/examples/generics.opt b/optd-cli/examples/generics.opt similarity index 99% rename from optd/src/dsl/examples/generics.opt rename to optd-cli/examples/generics.opt index b8930472..004a7687 100644 --- a/optd/src/dsl/examples/generics.opt +++ b/optd-cli/examples/generics.opt @@ -61,4 +61,4 @@ fn main(): I64 = filtered = sorted.filter(x: I64 -> x > 2), result = memoized_factorial(7) in - result \ No newline at end of file + result diff --git a/optd/src/dsl/examples/higher_order.opt b/optd-cli/examples/higher_order.opt similarity index 99% rename from optd/src/dsl/examples/higher_order.opt rename to optd-cli/examples/higher_order.opt index 3b07d932..516ce6e1 100644 --- a/optd/src/dsl/examples/higher_order.opt +++ b/optd-cli/examples/higher_order.opt @@ -97,4 +97,4 @@ fn main(log: Logical): Logical = double_and_optimize(expr1) )( map_binary(expr2, (x: Logical) -> Add(x, Number(1))) - ) \ No newline at end of file + ) diff --git a/optd/src/dsl/examples/lists.opt b/optd-cli/examples/lists.opt similarity index 92% rename from optd/src/dsl/examples/lists.opt rename to optd-cli/examples/lists.opt index d2f98d27..9c891b1d 100644 --- a/optd/src/dsl/examples/lists.opt +++ b/optd-cli/examples/lists.opt @@ -24,4 +24,4 @@ fn main(log: Logical): Logical? = match log \ _ -> fail("brru") fn (log: [Logical]) vla: [Logical] = match log - \ [Const(5) .. [List([Const(5) .. _]) .. v]] -> v \ No newline at end of file + \ [Const(5) .. [List([Const(5) .. _]) .. 
v]] -> v diff --git a/optd/src/dsl/examples/tutorial.opt b/optd-cli/examples/tutorial.opt similarity index 99% rename from optd/src/dsl/examples/tutorial.opt rename to optd-cli/examples/tutorial.opt index 5c268755..957b94c1 100644 --- a/optd/src/dsl/examples/tutorial.opt +++ b/optd-cli/examples/tutorial.opt @@ -409,4 +409,4 @@ fn (log: Logical*) derive(): LogicalProperties = match log new_columns = [] // TODO: Some remapping code here. in LogicalProperties(Schema(new_columns)) - \ Sort(child, _) -> child.properties() \ No newline at end of file + \ Sort(child, _) -> child.properties() diff --git a/optd/src/dsl/cli/main.rs b/optd-cli/src/main.rs similarity index 99% rename from optd/src/dsl/cli/main.rs rename to optd-cli/src/main.rs index c53c0ffc..d3e0a6cf 100644 --- a/optd/src/dsl/cli/main.rs +++ b/optd-cli/src/main.rs @@ -28,14 +28,13 @@ //! cargo run --bin optd-cli -- compile path/to/example.opt --mock-udfs hello get_schema world //! ``` -use std::collections::HashMap; - use clap::{Parser, Subcommand}; use colored::Colorize; use optd::catalog::Catalog; use optd::dsl::analyzer::hir::{CoreData, Udf, Value}; use optd::dsl::compile::{Config, compile_hir}; use optd::dsl::utils::errors::{CompileError, Diagnose}; +use std::collections::HashMap; #[derive(Parser)] #[command( diff --git a/optd/Cargo.toml b/optd/Cargo.toml index 9cec381f..2e622312 100644 --- a/optd/Cargo.toml +++ b/optd/Cargo.toml @@ -4,10 +4,6 @@ version = "0.1.0" edition = "2024" repository = "https://github.com/cmu-db/optd" -[[bin]] -name = "optd-cli" -path = "src/dsl/cli/main.rs" - [dependencies] ariadne = "0.5.1" async-recursion = "1.1.1" From 93acd84266420cea9460c13cbb95e29251e98ae4 Mon Sep 17 00:00:00 2001 From: AlSchlo <79570602+AlSchlo@users.noreply.github.com> Date: Sat, 3 May 2025 17:10:10 +0200 Subject: [PATCH 6/8] (feat) Add Type Bounds (#98) Supports generic bounds in the type system. Partially tackles #81 . ``` fn (pairs: [(K, V)]) to_map(): {K: V} = match pairs | [head .. tail] -> {head#_0: head#_1} ++ tail.to_map() \ [] -> {} ``` Also removed the `None` type, as it was really just an `Option`. This creates some issues when comparing an option with `None`. I need to see what can be done to fix that small bug. 
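For context, a minimal Rust analogue (not optd code) of the DSL snippet above: the `EqHash` bound on `K` plays the same role as Rust's `Eq + Hash` here, i.e. key types must support equality and hashing before the pairs can be collected into a map.

```rust
use std::collections::HashMap;
use std::hash::Hash;

// `K: Eq + Hash` is the Rust counterpart of the DSL's `K: EqHash` bound.
fn to_map<K: Eq + Hash, V>(pairs: Vec<(K, V)>) -> HashMap<K, V> {
    pairs.into_iter().collect()
}

fn main() {
    let ages = to_map(vec![("Alice".to_string(), 25), ("Charlie".to_string(), 30)]);
    assert_eq!(ages.get("Alice"), Some(&25));
    // Dropping the `Eq + Hash` bound makes this fail to type-check, which is the
    // kind of error the new generic bounds let the DSL surface as well.
}
```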
--- optd-cli/examples/generics.opt | 91 ++-- optd-cli/examples/tutorial.opt | 46 +- optd/src/dsl/analyzer/errors.rs | 22 +- optd/src/dsl/analyzer/from_ast/converter.rs | 34 +- optd/src/dsl/analyzer/from_ast/expr.rs | 34 +- optd/src/dsl/analyzer/from_ast/pattern.rs | 4 +- optd/src/dsl/analyzer/from_ast/types.rs | 18 +- .../src/dsl/analyzer/type_checks/converter.rs | 25 +- optd/src/dsl/analyzer/type_checks/generate.rs | 7 +- optd/src/dsl/analyzer/type_checks/glb.rs | 144 ++++-- optd/src/dsl/analyzer/type_checks/lub.rs | 170 +++++-- optd/src/dsl/analyzer/type_checks/registry.rs | 179 +++---- optd/src/dsl/analyzer/type_checks/solver.rs | 444 +++++++++++++++++- optd/src/dsl/analyzer/type_checks/subtype.rs | 148 ++++-- optd/src/dsl/parser/ast.rs | 2 +- optd/src/dsl/parser/function.rs | 248 +++++++++- 16 files changed, 1245 insertions(+), 371 deletions(-) diff --git a/optd-cli/examples/generics.opt b/optd-cli/examples/generics.opt index 004a7687..8cc18771 100644 --- a/optd-cli/examples/generics.opt +++ b/optd-cli/examples/generics.opt @@ -3,62 +3,43 @@ data Physical data LogicalProperties data PhysicalProperties -// Generic functions with recursive closures -fn memoize(f: T -> U): T -> U = - let - cache = [] // Simplified cache - in - (x: T) -> { - // In a real implementation, we would check the cache - // and only compute if needed - f(x) - } +// Convert array of key-value pairs to a map with EqHash constraint +fn (pairs: [(K, V)]) to_map(): {K: V} = match pairs +| [head .. tail] -> {head#_0: head#_1} ++ tail.to_map() +\ [] -> {} -fn (array: [T]) quicksort(cmp: (T, T) -> I64) = - match array - | [] -> [] - \ [pivot .. rest] -> { - let - partition = (arr: [T], pivot: T, cmp: (T, T) -> I64) -> { - let - less = arr.filter(x: T -> cmp(x, pivot) < 0), - greater = arr.filter(x: T -> cmp(x, pivot) >= 0) - in - (less, greater) - }, +// Simple filter function for arrays +fn (array: [T]) filter(pred: T -> Bool): [T] = match array +| [] -> [] +\ [x .. xs] -> + if pred(x) then [x] ++ xs.filter(pred) + else xs.filter(pred) - result = partition(rest, pivot, cmp), - less = result#_0, - greater = result#_1, - sorted_less = less.quicksort(cmp), - sorted_greater = greater.quicksort(cmp) - in - sorted_less ++ [pivot] ++ sorted_greater - } +// Simple map function +fn (array: [T]) map(f: T -> U): [U] = match array +| [] -> [] +\ [x .. xs] -> [f(x)] ++ xs.map(f) -fn (array: [T]) filter(predicate: T -> Bool) = match array - | [] -> [] - \ [x .. 
xs] -> - if predicate(x) then - [x] ++ xs.filter(predicate) - else - xs.filter(predicate) +// Data type for our test +data User(name: String, age: I64) -// Fixed recursive function implementation -fn factorial(n: I64): I64 = - if n <= 1 then 1 else n * factorial(n - 1) - -fn main(): I64 = - let - numbers = [5, 3, 8, 1, 2, 9, 4, 7, 6], - compare = (a: I64, b: I64) -> a - b, - - // Use regular factorial function instead of the Y-combinator approach - memoized_factorial = memoize(factorial), - - // Sort numbers and filter - sorted = numbers.quicksort(compare), - filtered = sorted.filter(x: I64 -> x > 2), - result = memoized_factorial(7) - in - result +fn main(): {String: I64} = +let + // Create some users + users = [ + User("Alice", 25), + User("Bob", 17), + User("Charlie", 30), + User("Diana", 15) + ], + + // Filter for adults + adults = users.filter((u: User) -> u#age >= 18), + + // Create name -> age map + name_age_pairs = adults.map((u: User) -> (u#name, u#age)), + + // Convert to map + age_map = name_age_pairs.to_map() +in + age_map \ No newline at end of file diff --git a/optd-cli/examples/tutorial.opt b/optd-cli/examples/tutorial.opt index 957b94c1..14bca04b 100644 --- a/optd-cli/examples/tutorial.opt +++ b/optd-cli/examples/tutorial.opt @@ -16,7 +16,7 @@ // You can compile this tutorial file with: // -// cargo run -- compile tutorial.opt +// cargo run --bin optd-cli -- compile [path] --mock-udfs map get_table_schema properties statistics optimize // ------------------------- // 1. Logical Operators @@ -52,7 +52,7 @@ data Scalar = | IntLiteral(value: I64) | StringLiteral(value: String) \ BoolLiteral(value: Bool) - | Arithmetic = + | Arith = | Mult(left: Scalar, right: Scalar) | Add(left: Scalar, right: Scalar) | Sub(left: Scalar, right: Scalar) @@ -65,7 +65,7 @@ data Scalar = | Function = | Cast(expr: Scalar, target_type: String) | Substring(str: Scalar, start: Scalar, length: Scalar) - \ Concat(args: [Scalar]) + \ Concatenate(args: [Scalar]) \ AggregateExpr = | Sum(expr: Scalar) | Count(expr: Scalar) @@ -213,7 +213,7 @@ fn (list: [E]) map(f: (E) -> F): [F] = match list | [head .. tail] -> [f(head)] ++ tail.map(f) \ [] -> [] -fn (pairs: [(K, V)]) to_map(): {K : V} = match pairs +fn (pairs: [(K, V)]) to_map(): {K : V} = match pairs | [head .. tail] -> {head#_0 : head#_1} ++ tail.to_map() \ [] -> {} @@ -238,33 +238,33 @@ fn (costed: Physical$) statistics(): CostedProperties // These are annotated with [transformation] and convert between logical plans: // Helper function for scalar rewrites. -fn (expr: Scalar) remap(map: {I64 : I64}): Scalar = +fn (expr: Scalar) remap(bindings: {I64 : I64}): Scalar = match expr | ColumnRef(idx) -> - if map(idx) != none then - ColumnRef(map(idx)) + if bindings(idx) != none then // TODO: Fix none bug... Should not be a type! Rather Option. + ColumnRef(0) // TODO: Add ! once we have the `!` syntax. 
else ColumnRef(idx) | IntLiteral(value) -> IntLiteral(value) | StringLiteral(value) -> StringLiteral(value) | BoolLiteral(value) -> BoolLiteral(value) - | Mult(left, right) -> Mult(left.remap(map), right.remap(map)) - | Add(left, right) -> Add(left.remap(map), right.remap(map)) - | Sub(left, right) -> Sub(left.remap(map), right.remap(map)) - | Div(left, right) -> Div(left.remap(map), right.remap(map)) - | And(children) -> And(children.map(child -> child.remap(map))) - | Or(children) -> Or(children.map(child -> child.remap(map))) - | Not(child) -> Not(child.remap(map)) - | Equals(left, right) -> Equals(left.remap(map), right.remap(map)) - | Cast(expr, target_type) -> Cast(expr.remap(map), target_type) + | Mult(left, right) -> Mult(left.remap(bindings), right.remap(bindings)) + | Add(left, right) -> Add(left.remap(bindings), right.remap(bindings)) + | Sub(left, right) -> Sub(left.remap(bindings), right.remap(bindings)) + | Div(left, right) -> Div(left.remap(bindings), right.remap(bindings)) + | And(children) -> And(children.map(child: Scalar -> child.remap(bindings))) + | Or(children) -> Or(children.map(child: Scalar -> child.remap(bindings))) + | Not(child) -> Not(child.remap(bindings)) + | Equals(left, right) -> Equals(left.remap(bindings), right.remap(bindings)) + | Cast(expr, target_type) -> Cast(expr.remap(bindings), target_type) | Substring(str, start, length) -> - Substring(str.remap(map), start.remap(map), length.remap(map)) - | Concat(args) -> Concat(args.map(arg -> arg.remap(map))) - | Sum(expr) -> Sum(expr.remap(map)) - | Count(expr) -> Count(expr.remap(map)) - | Min(expr) -> Min(expr.remap(map)) - | Max(expr) -> Max(expr.remap(map)) - \ Avg(expr) -> Avg(expr.remap(map)) + Substring(str.remap(bindings), start.remap(bindings), length.remap(bindings)) + | Concatenate(args) -> Concatenate(args.map(arg: Scalar -> arg.remap(bindings))) + | Sum(expr) -> Sum(expr.remap(bindings)) + | Count(expr) -> Count(expr.remap(bindings)) + | Min(expr) -> Min(expr.remap(bindings)) + | Max(expr) -> Max(expr.remap(bindings)) + \ Avg(expr) -> Avg(expr.remap(bindings)) [transformation] fn (expr: Logical*) join_commute(): Logical? = match expr diff --git a/optd/src/dsl/analyzer/errors.rs b/optd/src/dsl/analyzer/errors.rs index eeb05999..5554c2a1 100644 --- a/optd/src/dsl/analyzer/errors.rs +++ b/optd/src/dsl/analyzer/errors.rs @@ -147,6 +147,11 @@ pub enum AnalyzerErrorKind { // To be able to call display function of Type unknowns: HashMap, }, + + ReservedType { + name: String, + span: Span, + }, } impl AnalyzerErrorKind { @@ -339,7 +344,6 @@ impl AnalyzerErrorKind { .into() } - // New constructor for array decomposition errors pub fn new_invalid_array_decomposition( scrutinee_span: &Span, pattern_span: &Span, @@ -354,6 +358,14 @@ impl AnalyzerErrorKind { } .into() } + + pub fn new_reserved_type(name: &str, span: &Span) -> Box { + Self::ReservedType { + name: name.to_string(), + span: span.clone(), + } + .into() + } } impl Diagnose for Box { @@ -504,13 +516,18 @@ impl Diagnose for Box { ), ) }, - // Handler for the new InvalidArrayDecomposition error InvalidArrayDecomposition { scrutinee_span, pattern_span, scrutinee_type, unknowns, } => self.build_array_decomp_error_report(scrutinee_span, pattern_span, scrutinee_type, unknowns), + ReservedType { name, span } => self.build_single_span_report( + span, + &format!("Reserved type name: '{}'", name), + &format!("'{}' is a reserved type name", name), + "Choose a different name for your type. 
Reserved type names are used internally by the system", + ), } } @@ -537,6 +554,7 @@ impl Diagnose for Box { InvalidTransformation { span, .. } => span, InvalidImplementation { span, .. } => span, InvalidArrayDecomposition { pattern_span, .. } => pattern_span, // Use pattern span as primary + ReservedType { span, .. } => span, // New case for ReservedType }; (span.src_file.clone(), Source::from(self.src_code.clone())) diff --git a/optd/src/dsl/analyzer/from_ast/converter.rs b/optd/src/dsl/analyzer/from_ast/converter.rs index 149612b9..0cecf220 100644 --- a/optd/src/dsl/analyzer/from_ast/converter.rs +++ b/optd/src/dsl/analyzer/from_ast/converter.rs @@ -3,7 +3,7 @@ use crate::dsl::analyzer::hir::context::Context; use crate::dsl::analyzer::hir::{Annotation, FunKind, Identifier}; use crate::dsl::analyzer::hir::{CoreData, TypedSpan, Udf, Value}; use crate::dsl::analyzer::type_checks::converter::create_function_type; -use crate::dsl::analyzer::type_checks::registry::{Type, TypeRegistry}; +use crate::dsl::analyzer::type_checks::registry::{Generic, Type, TypeRegistry}; use crate::dsl::parser::ast::Function; use crate::dsl::utils::span::Spanned; use FunKind::*; @@ -53,7 +53,7 @@ impl ASTConverter { let generics = { let mut generics_map = HashMap::new(); - for param in &func.type_params { + for (param, bound) in &func.type_params { let param_name = &*param.value; // Check for duplicates. @@ -65,16 +65,25 @@ impl ASTConverter { )); } + // Convert and check bound. + let bound_ty = bound + .clone() + .map(|b| self.convert_type(&b, &HashMap::new(), true)) + .transpose()?; + // Assign ID and store. let id = self.registry.next_id; self.registry.next_id += 1; - generics_map.insert(param_name.clone(), (id, param.span.clone())); + generics_map.insert( + param_name.clone(), + (Generic(id, bound_ty), param.span.clone()), + ); } // Extract the final mapping of param names to IDs. generics_map .into_iter() - .map(|(name, (id, _))| (name, id)) + .map(|(name, (generic, _))| (name, generic)) .collect() }; @@ -133,7 +142,7 @@ impl ASTConverter { fn get_parameters( &mut self, func: &Function, - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { // Start with receiver if it exists. 
let mut param_fields = match &func.receiver { @@ -171,7 +180,7 @@ mod converter_tests { use crate::catalog::Catalog; use crate::dsl::analyzer::from_ast::from_ast; use crate::dsl::analyzer::hir::{CoreData, FunKind}; - use crate::dsl::analyzer::type_checks::registry::TypeKind; + use crate::dsl::analyzer::type_checks::registry::{Generic, TypeKind}; use crate::dsl::parser::ast::{self, Adt, Function, Item, Module, Type as AstType}; use crate::dsl::utils::span::{Span, Spanned}; @@ -437,7 +446,10 @@ mod converter_tests { let mut func_val = (*func.value).clone(); // Add type parameters - func_val.type_params = vec![spanned(String::from("T")), spanned(String::from("U"))]; + func_val.type_params = vec![ + (spanned(String::from("T")), None), + (spanned(String::from("U")), None), + ]; // Modify the return type to use a generic func_val.return_type = spanned(AstType::Identifier(String::from("T"))); @@ -460,7 +472,7 @@ mod converter_tests { match &*func_val.unwrap().metadata.ty.value { TypeKind::Closure(_, ret_type) => { match &*ret_type.value { - TypeKind::Generic(id) => { + TypeKind::Gen(Generic(id, _)) => { // We expect id to be 0 since "T" should be the first generic parameter assert_eq!(*id, 0); } @@ -479,9 +491,9 @@ mod converter_tests { // Add type parameters with a duplicate func_val.type_params = vec![ - spanned(String::from("T")), - spanned(String::from("U")), - spanned(String::from("T")), // Duplicate of "T" + (spanned(String::from("T")), None), + (spanned(String::from("U")), None), + (spanned(String::from("T")), None), // Duplicate of "T" ]; let func = spanned(func_val); diff --git a/optd/src/dsl/analyzer/from_ast/expr.rs b/optd/src/dsl/analyzer/from_ast/expr.rs index 61a18178..cdccb96c 100644 --- a/optd/src/dsl/analyzer/from_ast/expr.rs +++ b/optd/src/dsl/analyzer/from_ast/expr.rs @@ -9,7 +9,7 @@ use crate::dsl::analyzer::hir::{ BinOp, CoreData, Expr, ExprKind, FunKind, Identifier, LetBinding, Literal, TypedSpan, UnaryOp, }; use crate::dsl::analyzer::type_checks::converter::create_function_type; -use crate::dsl::analyzer::type_checks::registry::{Type, TypeKind}; +use crate::dsl::analyzer::type_checks::registry::{Generic, Type, TypeKind}; use crate::dsl::parser::ast::{ self, BinOp as AstBinOp, Expr as AstExpr, Literal as AstLiteral, PostfixOp, }; @@ -26,7 +26,7 @@ impl ASTConverter { pub(super) fn convert_expr( &mut self, spanned_expr: &Spanned, - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { use TypeKind::*; @@ -107,7 +107,7 @@ impl ASTConverter { self.convert_fail(error_expr, generics)? 
} AstExpr::None => { - ty = None.into(); + ty = Optional(Nothing.into()).into(); CoreExpr(CoreData::None) } AstExpr::Block(block) => self.convert_block(block, generics)?, @@ -142,7 +142,7 @@ impl ASTConverter { op: &AstBinOp, right: &Spanned, span: &Span, - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { use BinOp::*; @@ -237,7 +237,7 @@ impl ASTConverter { &mut self, op: &ast::UnaryOp, operand: &Spanned, - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { let hir_operand = self.convert_expr(operand, generics)?; @@ -254,7 +254,7 @@ impl ASTConverter { field: &Spanned, init: &Spanned, body: &Spanned, - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { let hir_init = self.convert_expr(init, generics)?; let hir_body = self.convert_expr(body, generics)?; @@ -271,7 +271,7 @@ impl ASTConverter { &mut self, scrutinee: &Spanned, arms: &[Spanned], - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { let hir_scrutinee = self.convert_expr(scrutinee, generics)?; let hir_arms = self.convert_match_arms(arms, generics)?; @@ -284,7 +284,7 @@ impl ASTConverter { condition: &Spanned, then_branch: &Spanned, else_branch: &Spanned, - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { let hir_condition = self.convert_expr(condition, generics)?; let hir_then = self.convert_expr(then_branch, generics)?; @@ -300,7 +300,7 @@ impl ASTConverter { fn convert_expr_list( &mut self, elements: &[Spanned], - generics: &HashMap, + generics: &HashMap, ) -> Result>>, Box> { let mut hir_elements = Vec::with_capacity(elements.len()); @@ -315,7 +315,7 @@ impl ASTConverter { fn convert_array( &mut self, elements: &[Spanned], - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { let hir_elements = self.convert_expr_list(elements, generics)?; Ok(CoreExpr(CoreData::Array(hir_elements))) @@ -324,7 +324,7 @@ impl ASTConverter { fn convert_tuple( &mut self, elements: &[Spanned], - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { let hir_elements = self.convert_expr_list(elements, generics)?; Ok(CoreExpr(CoreData::Tuple(hir_elements))) @@ -333,7 +333,7 @@ impl ASTConverter { fn convert_map( &mut self, entries: &[(Spanned, Spanned)], - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { let mut hir_entries = Vec::with_capacity(entries.len()); @@ -382,7 +382,7 @@ impl ASTConverter { name: &Spanned, args: &[Spanned], span: &Span, - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { self.validate_constructor(name, span, args.len())?; @@ -394,7 +394,7 @@ impl ASTConverter { &mut self, params: &[(Identifier, Type)], body: &Spanned, - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { let param_names = params.iter().map(|(name, _)| name.clone()).collect(); let hir_body = self.convert_expr(body, generics)?; @@ -409,7 +409,7 @@ impl ASTConverter { &mut self, expr: &Spanned, op: &PostfixOp, - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { let hir_expr = self.convert_expr(expr, generics)?; @@ -443,7 +443,7 @@ impl ASTConverter { fn convert_fail( &mut self, error_expr: &Spanned, - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { let hir_error = self.convert_expr(error_expr, generics)?; @@ -453,7 +453,7 @@ impl ASTConverter { fn convert_block( &mut self, block: &Spanned, - generics: &HashMap, + generics: &HashMap, ) -> Result, Box> { let hir_expr = self.convert_expr(block, generics)?; diff --git a/optd/src/dsl/analyzer/from_ast/pattern.rs b/optd/src/dsl/analyzer/from_ast/pattern.rs index bd015cc4..40984a22 
100644 --- a/optd/src/dsl/analyzer/from_ast/pattern.rs +++ b/optd/src/dsl/analyzer/from_ast/pattern.rs @@ -6,7 +6,7 @@ use super::converter::ASTConverter; use crate::dsl::analyzer::errors::AnalyzerErrorKind; use crate::dsl::analyzer::hir::{Identifier, MatchArm, Pattern, PatternKind, TypedSpan}; -use crate::dsl::analyzer::type_checks::registry::TypeKind; +use crate::dsl::analyzer::type_checks::registry::{Generic, TypeKind}; use crate::dsl::parser::ast::{self, Pattern as AstPattern}; use crate::dsl::utils::span::Spanned; use std::collections::HashMap; @@ -20,7 +20,7 @@ impl ASTConverter { pub(super) fn convert_match_arms( &mut self, arms: &[Spanned], - generics: &HashMap, + generics: &HashMap, ) -> Result>, Box> { arms.iter() .map(|arm| { diff --git a/optd/src/dsl/analyzer/from_ast/types.rs b/optd/src/dsl/analyzer/from_ast/types.rs index 5c6c040e..9c6f07cd 100644 --- a/optd/src/dsl/analyzer/from_ast/types.rs +++ b/optd/src/dsl/analyzer/from_ast/types.rs @@ -6,7 +6,7 @@ use super::converter::ASTConverter; use crate::dsl::analyzer::errors::AnalyzerErrorKind; use crate::dsl::analyzer::hir::Identifier; -use crate::dsl::analyzer::type_checks::registry::{Type, TypeKind}; +use crate::dsl::analyzer::type_checks::registry::{Generic, RESERVED_TYPE_MAP, Type, TypeKind}; use crate::dsl::parser::ast::Type as AstType; use crate::dsl::utils::span::Spanned; use std::collections::HashMap; @@ -35,7 +35,7 @@ impl ASTConverter { pub(super) fn convert_type( &mut self, ast_type: &Spanned, - generics: &HashMap, + generics: &HashMap, ascending: bool, ) -> Result> { use TypeKind::*; @@ -74,8 +74,10 @@ impl ASTConverter { Costed(self.convert_type(inner_type, generics, ascending)?) } AstType::Identifier(name) => { - if let Some(id) = generics.get(name) { - Generic(*id) + if let Some(generic) = generics.get(name) { + Gen(generic.clone()) + } else if let Some(type_kind) = RESERVED_TYPE_MAP.get(name) { + type_kind.clone() } else { if !self.registry.subtypes.contains_key(name) { return Err(AnalyzerErrorKind::new_undefined_type(name, &ast_type.span)); @@ -277,12 +279,12 @@ mod types_tests { let generic_type = AstType::Identifier("T".to_string()); let mut generics = HashMap::new(); // Assign ID 42 to the generic "T" - generics.insert("T".to_string(), 42); + generics.insert("T".to_string(), Generic(42, None)); let result = converter .convert_type(&spanned(generic_type), &generics, true) .expect("Generic type conversion should succeed"); match &*result.value { - TypeKind::Generic(id) => assert_eq!(*id, 42), + TypeKind::Gen(generic) => assert_eq!(generic.0, 42), _ => panic!("Expected Generic type"), } } @@ -450,8 +452,8 @@ mod types_tests { // Setup generics with numeric IDs let mut generics = HashMap::new(); - generics.insert("T".to_string(), 0); - generics.insert("U".to_string(), 1); + generics.insert("T".to_string(), Generic(0, None)); + generics.insert("U".to_string(), Generic(1, None)); // Test generic types (should pass validation because they're in the generics map) let generic_type = AstType::Identifier("T".to_string()); diff --git a/optd/src/dsl/analyzer/type_checks/converter.rs b/optd/src/dsl/analyzer/type_checks/converter.rs index 2ecc1cd9..b58bddc1 100644 --- a/optd/src/dsl/analyzer/type_checks/converter.rs +++ b/optd/src/dsl/analyzer/type_checks/converter.rs @@ -1,4 +1,5 @@ use super::registry::{Type, TypeKind}; +use crate::dsl::analyzer::type_checks::registry::{Generic, RESERVED_TYPE_MAP}; use crate::dsl::parser::ast::Type as AstType; use crate::dsl::utils::span::{OptionalSpanned, Spanned}; use 
std::collections::HashMap; @@ -20,7 +21,13 @@ pub(crate) fn convert_ast_type(ast_ty: Spanned) -> Type { let span = ast_ty.span; let kind = match *ast_ty.value { - AstType::Identifier(name) => Adt(name), + AstType::Identifier(name) => { + if let Some(type_kind) = RESERVED_TYPE_MAP.get(&name) { + type_kind.clone() + } else { + Adt(name.clone()) + } + } AstType::Int64 => I64, AstType::String => String, AstType::Bool => Bool, @@ -99,25 +106,33 @@ pub(crate) fn type_display(ty: &Type, resolved_unknown: &HashMap) - Unit => "()".to_string(), Universe => "Universe".to_string(), Nothing => "Nothing".to_string(), - None => "None".to_string(), // Unknown types UnknownAsc(id) => { format!( - "≧{{{}}}", + "≧`{}`", type_display(resolved_unknown.get(id).unwrap(), resolved_unknown) ) } UnknownDesc(id) => { format!( - "≦{{{}}}", + "≦`{}`", type_display(resolved_unknown.get(id).unwrap(), resolved_unknown) ) } // User types Adt(name) => name.to_string(), - Generic(name) => format!("Gen<#{}>", name), + Gen(Generic(id, bound)) => match bound { + Some(bound_type) => { + format!( + "Gen<#{}: {}>", + id, + type_display(bound_type, resolved_unknown) + ) + } + Option::None => format!("Gen<#{}>", id), + }, // Composite types Array(elem) => format!("[{}]", type_display(elem, resolved_unknown)), diff --git a/optd/src/dsl/analyzer/type_checks/generate.rs b/optd/src/dsl/analyzer/type_checks/generate.rs index b1693b1d..bbd3f2fb 100644 --- a/optd/src/dsl/analyzer/type_checks/generate.rs +++ b/optd/src/dsl/analyzer/type_checks/generate.rs @@ -553,7 +553,10 @@ mod scope_check_tests { let fun_val = Value { data: CoreData::Function(FunKind::Closure(params, Arc::new(body))), metadata: TypedSpan { - ty: create_function_type(&vec![TypeKind::None.into(); len], &TypeKind::None.into()), + ty: create_function_type( + &vec![TypeKind::Nothing.into(); len], + &TypeKind::Nothing.into(), + ), span: create_test_span(), }, }; @@ -757,7 +760,7 @@ mod scope_check_tests { create_test_span(), )), ))), - TypeKind::Closure(TypeKind::None.into(), TypeKind::None.into()).into(), + TypeKind::Closure(TypeKind::Nothing.into(), TypeKind::Nothing.into()).into(), create_test_span(), ); diff --git a/optd/src/dsl/analyzer/type_checks/glb.rs b/optd/src/dsl/analyzer/type_checks/glb.rs index 96f867fa..1350b08b 100644 --- a/optd/src/dsl/analyzer/type_checks/glb.rs +++ b/optd/src/dsl/analyzer/type_checks/glb.rs @@ -12,19 +12,21 @@ impl TypeRegistry { /// contravariance for return types. /// 4. For ADTs, it finds the more specific of the two types if one is a subtype of the other, given /// the absence of multiple inheritance within ADTs. - /// 5. For wrapper types: + /// 5. For generic types: + /// - If either type is generic, check if one is a subtype of the other and return the subtype. + /// - If no subtype relationship exists, return Nothing. + /// 6. For wrapper types: /// - Optional(T) and Optional(U) yields Optional(GLB(T, U)). /// - Optional(T) and U yields GLB(T, U). - /// - None is the GLB of None and any Optional type. /// - For Stored/Costed: Costed is more specific than Stored, so mixed operations yield Costed. - /// 6. For compatibility between collections and functions: + /// 7. For compatibility between collections and functions: /// - Map and Function: returns the more specific type if one is a subtype of the other, /// or computes a common Function type if no direct subtyping relationship exists. 
/// - Array and Function: returns the more specific type if one is a subtype of the other, /// or computes a common Function type if no direct subtyping relationship exists. /// Only applies when the Function parameter type is I64. - /// 7. Nothing is the bottom type, Universe is the top type, so GLB(Universe, T) = T and GLB(Nothing, T) = Nothing. - /// 8. If no common subtype exists, returns the Nothing type. + /// 8. Nothing is the bottom type, Universe is the top type, so GLB(Universe, T) = T and GLB(Nothing, T) = Nothing. + /// 9. If no common subtype exists, returns the Nothing type. /// /// # Arguments /// @@ -55,9 +57,15 @@ impl TypeRegistry { return self.greatest_lower_bound(type1, &bound_unknown, has_changed); } - // Handle generics. - (Generic(id1), Generic(id2)) if id1 == id2 => { - return type1.clone(); + // Generic types: just check if one is a subtype of the other. + (Gen(_), _) | (_, Gen(_)) => { + if self.is_subtype_infer(type1, type2, has_changed) { + return type1.clone(); + } else if self.is_subtype_infer(type2, type1, has_changed) { + return type2.clone(); + } else { + return Nothing.into(); + } } // Nothing is the bottom type - GLB with anything is Nothing. @@ -68,12 +76,9 @@ impl TypeRegistry { (other, Universe) => other.clone(), // Primitive types - check for equality. - (I64, I64) - | (String, String) - | (F64, F64) - | (Bool, Bool) - | (Unit, Unit) - | (None, None) => *type1.value.clone(), + (I64, I64) | (String, String) | (F64, F64) | (Bool, Bool) | (Unit, Unit) => { + *type1.value.clone() + } // Array covariance: GLB(Array, Array) = Array. (Array(elem1), Array(elem2)) => { @@ -102,7 +107,6 @@ impl TypeRegistry { let glb_inner = self.greatest_lower_bound(inner1, inner2, has_changed); Optional(glb_inner) } - (None, Optional(_)) | (Optional(_), None) => None, (Optional(inner), _) => { return self.greatest_lower_bound(inner, type2, has_changed); } @@ -237,7 +241,7 @@ impl TypeRegistry { #[cfg(test)] mod tests { use super::*; - use crate::dsl::analyzer::type_checks::lub::tests::setup_type_hierarchy; + use crate::dsl::analyzer::type_checks::{lub::tests::setup_type_hierarchy, registry::Generic}; use TypeKind::*; /// Helper function to simplify GLB assertions @@ -542,17 +546,6 @@ mod tests { Optional(Nothing.into()), ); - // None and Optional - assert_glb_eq( - &mut reg, - &None.into(), - &Optional(Adt("Dog".to_string()).into()).into(), - None, - ); - - // None and None - assert_glb_eq(&mut reg, &None.into(), &None.into(), None); - // Optional and base type assert_glb_eq( &mut reg, @@ -719,4 +712,101 @@ mod tests { ), ); } + + #[test] + fn test_greatest_lower_bound_generics() { + let mut reg = setup_type_hierarchy(); + + // Set up ADT types for testing + let dog_type: Type = Adt("Dog".to_string()).into(); + let cat_type = Adt("Cat".to_string()).into(); + let mammals_type: Type = Adt("Mammals".to_string()).into(); + let animals_type = Adt("Animals".to_string()).into(); + + // Create generic types with ADT bounds + let generic_1_mammals = Gen(Generic(1, Some(mammals_type.clone()))).into(); + let generic_2_dog = Gen(Generic(2, Some(dog_type.clone()))).into(); + + // Test case 1: Same generic ID is its own GLB + assert_glb_eq( + &mut reg, + &generic_1_mammals, + &generic_1_mammals, + Gen(Generic(1, Some(mammals_type.clone()))), + ); + + // Test case 2: Dog <: Mammals, so Dog is more specific + assert_glb_eq( + &mut reg, + &dog_type, + &generic_1_mammals, + Adt("Dog".to_string()), + ); + + // Test case 3: Generic with Dog bound is more specific than Generic with Mammals bound 
+ assert_glb_eq( + &mut reg, + &generic_1_mammals, + &generic_2_dog, + Gen(Generic(2, Some(dog_type.clone()))), + ); + + // Test case 4: No relationship between Cat and generic with Dog bound + assert_glb_nothing(&mut reg, &cat_type, &generic_2_dog); + + // Test case 5: Animals >: Mammals, so generic with Mammals bound is more specific + assert_glb_eq( + &mut reg, + &animals_type, + &generic_1_mammals, + Gen(Generic(1, Some(mammals_type.clone()))), + ); + + // Test case 6: Generic with no bound + let generic_unbounded = Gen(Generic(10, Option::None)).into(); + + // Test unbounded generic with another type - just use subtype checks + assert_glb_nothing(&mut reg, &generic_unbounded, &dog_type); + + // Test unbounded generic with itself + assert_glb_eq( + &mut reg, + &generic_unbounded, + &generic_unbounded, + Gen(Generic(10, Option::None)), + ); + + // Test case 7: Container types with generics + let array_generic_mammals = Array(generic_1_mammals.clone()).into(); + let array_dog = Array(dog_type.clone()).into(); + + assert_glb_eq( + &mut reg, + &array_generic_mammals, + &array_dog, + Array(dog_type.clone()), + ); + + // Test case 8: Complex types with generics + let map_generic_mammals_dog = Map(generic_1_mammals.clone(), dog_type.clone()).into(); + let map_animals_dog = Map(animals_type.clone(), dog_type.clone()).into(); + + assert_glb_eq( + &mut reg, + &map_generic_mammals_dog, + &map_animals_dog, + Map(animals_type.clone(), dog_type.clone()), + ); + + // Test case 9: Optional with generics + let optional_generic_mammals = Optional(generic_1_mammals.clone()).into(); + let optional_dog = Optional(dog_type.clone()).into(); + + assert_glb_eq( + &mut reg, + &optional_generic_mammals, + &optional_dog, + Optional(dog_type.clone()), + ); + } } diff --git a/optd/src/dsl/analyzer/type_checks/lub.rs b/optd/src/dsl/analyzer/type_checks/lub.rs index 61d86b47..773f3246 100644 --- a/optd/src/dsl/analyzer/type_checks/lub.rs +++ b/optd/src/dsl/analyzer/type_checks/lub.rs @@ -10,19 +10,20 @@ impl TypeRegistry { /// 1. Primitive types check for equality. /// 2. For container types (Array, Tuple, etc.), it applies covariance rules. /// 3. For function (and Map) types, it uses contravariance for parameters and covariance for return types. - /// 4. For native trait types (Concat, EqHash, Arithmetic), the result is the trait only if both types + /// 4. For generic types: + /// - If either type is generic, check if one is a supertype of the other and return the supertype. + /// - If no supertype relationship exists between the generics, return Universe. + /// 5. For native trait types (Concat, EqHash, Arithmetic), the result is the trait only if both types /// implement it, and at least one of the types is the trait itself. - /// 5. For ADT types, it finds the closest common supertype in the type hierarchy. - /// 6. For wrapper types: + /// 6. For ADT types, it finds the closest common supertype in the type hierarchy. + /// 7. For wrapper types: /// - Optional preserves the wrapper and computes LUB of inner types - /// - None and Optional(T) yields Optional(T) - /// - None and non-Optional T yields Optional(T) /// - For Stored/Costed: Costed is considered more specific than Stored, and /// either wrapper can be removed when comparing with non-wrapped types - /// 7. Map types can be viewed as functions from keys to optional values, and + /// 8. Map types can be viewed as functions from keys to optional values, and /// Arrays as functions from indices to values, with appropriate type conversions. - /// 8. 
Nothing is the bottom type, Universe it the top type; LUB(Nothing, T) = T and LUB(Universe, T) = Universe. - /// 9. If no meaningful upper bound exists, Universe is returned. + /// 9. Nothing is the bottom type, Universe it the top type; LUB(Nothing, T) = T and LUB(Universe, T) = Universe. + /// 10. If no meaningful upper bound exists, Universe is returned. /// /// # Arguments /// @@ -53,9 +54,15 @@ impl TypeRegistry { return self.least_upper_bound(type1, &bound_unknown, has_changed); } - // Handle generics. - (Generic(id1), Generic(id2)) if id1 == id2 => { - return type1.clone(); + // Generic types: just check if one is a supertype of the other. + (Gen(_), _) | (_, Gen(_)) => { + if self.is_subtype_infer(type1, type2, has_changed) { + return type2.clone(); + } else if self.is_subtype_infer(type2, type1, has_changed) { + return type1.clone(); + } else { + return Universe.into(); + } } // Universe is the top type - LUB(Universe, T) = Universe. @@ -66,12 +73,9 @@ impl TypeRegistry { (other, Nothing) => other.clone(), // Primitive types - check for equality. - (I64, I64) - | (String, String) - | (F64, F64) - | (Bool, Bool) - | (Unit, Unit) - | (None, None) => *type1.value.clone(), + (I64, I64) | (String, String) | (F64, F64) | (Bool, Bool) | (Unit, Unit) => { + *type1.value.clone() + } // Array covariance: LUB(Array, Array) = Array. (Array(elem1), Array(elem2)) => { @@ -100,11 +104,8 @@ impl TypeRegistry { let lub_inner = self.least_upper_bound(inner1, inner2, has_changed); Optional(lub_inner) } - (None, Optional(inner)) | (Optional(inner), None) => Optional(inner.clone()), (Optional(inner), _) => Optional(self.least_upper_bound(inner, type2, has_changed)), (_, Optional(inner)) => Optional(self.least_upper_bound(type1, inner, has_changed)), - (None, _) => Optional(type2.clone()), - (_, None) => Optional(type1.clone()), // Stored type handling. 
(Stored(inner1), Stored(inner2)) => { @@ -288,8 +289,9 @@ impl TypeRegistry { pub mod tests { use super::*; use crate::dsl::{ - analyzer::type_checks::registry::type_registry_tests::{ - create_product_adt, create_sum_adt, + analyzer::type_checks::registry::{ + Generic, + type_registry_tests::{create_product_adt, create_sum_adt}, }, parser::ast::Type as AstType, }; @@ -768,22 +770,6 @@ pub mod tests { Optional(Adt("Animals".to_string()).into()), ); - // None and Optional ADT - assert_lub_eq( - &mut reg, - &None.into(), - &Optional(Adt("Dog".to_string()).into()).into(), - Optional(Adt("Dog".to_string()).into()), - ); - - // None and ADT type - assert_lub_eq( - &mut reg, - &None.into(), - &Adt("Dog".to_string()).into(), - Optional(Adt("Dog".to_string()).into()), - ); - // ADT and Optional related ADT assert_lub_eq( &mut reg, @@ -981,4 +967,112 @@ pub mod tests { ), ); } + + #[test] + fn test_generic_lub() { + let mut reg = setup_type_hierarchy(); + + // Set up ADT types for testing + let dog_type: Type = Adt("Dog".to_string()).into(); + let cat_type = Adt("Cat".to_string()).into(); + let mammals_type: Type = Adt("Mammals".to_string()).into(); + let animals_type = Adt("Animals".to_string()).into(); + + // Create generic types with ADT bounds + let generic_1_mammals = Gen(Generic(1, Some(mammals_type.clone()))).into(); + let generic_2_dog = Gen(Generic(2, Some(dog_type.clone()))).into(); + let generic_3_mammals = Gen(Generic(3, Some(mammals_type.clone()))).into(); + + // Test case 1: Same generic ID is its own LUB + assert_lub_eq( + &mut reg, + &generic_1_mammals, + &generic_1_mammals, + Gen(Generic(1, Some(mammals_type.clone()))), + ); + + // Test case 2: Dog <: Mammals, so Mammals is the LUB + assert_lub_eq( + &mut reg, + &dog_type, + &generic_1_mammals, + Gen(Generic(1, Some(mammals_type.clone()))), + ); + + // Test case 3: Generic with Mammals bound is more general than Generic with Dog bound + assert_lub_eq( + &mut reg, + &generic_1_mammals, + &generic_2_dog, + Gen(Generic(1, Some(mammals_type.clone()))), + ); + + // Test case 4: Different generics with same bounds + // This depends on is_subtype_infer implementation + let mut has_changed = false; + let result = + reg.least_upper_bound(&generic_1_mammals, &generic_3_mammals, &mut has_changed); + + // Since these have the same bounds but different IDs, one of them should be returned + // based on the implementation of is_subtype_infer + assert!(result == generic_1_mammals || result == generic_3_mammals); + + // Test case 5: No relationship between Cat and generic with Dog bound + assert_lub_universe(&mut reg, &cat_type, &generic_2_dog); + + // Test case 6: Animals >: Mammals, so Animals is more general + assert_lub_eq( + &mut reg, + &animals_type, + &generic_1_mammals, + *animals_type.value.clone(), + ); + + // Test case 7: Generic with no bound + let generic_unbounded = Gen(Generic(10, Option::None)).into(); + + // Test unbounded generic with another type - just use subtype checks + assert_lub_universe(&mut reg, &generic_unbounded, &dog_type); + + // Test unbounded generic with itself + assert_lub_eq( + &mut reg, + &generic_unbounded, + &generic_unbounded, + Gen(Generic(10, Option::None)), + ); + + // Test case 8: Container types with generics + let array_generic_mammals = Array(generic_1_mammals.clone()).into(); + let array_dog = Array(dog_type.clone()).into(); + + assert_lub_eq( + &mut reg, + &array_generic_mammals, + &array_dog, + Array(generic_1_mammals.clone()), + ); + + // Test case 9: Complex types with generics + let 
map_generic_mammals_dog = Map(generic_1_mammals.clone(), dog_type.clone()).into(); + let map_animals_dog = Map(animals_type.clone(), dog_type.clone()).into(); + + assert_lub_eq( + &mut reg, + &map_generic_mammals_dog, + &map_animals_dog, + Map(generic_1_mammals.clone(), dog_type.clone()), + ); + + // Test case 10: Optional with generics + let optional_generic_mammals = Optional(generic_1_mammals.clone()).into(); + let optional_dog = Optional(dog_type.clone()).into(); + + assert_lub_eq( + &mut reg, + &optional_generic_mammals, + &optional_dog, + Optional(generic_1_mammals.clone()), + ); + } } diff --git a/optd/src/dsl/analyzer/type_checks/registry.rs b/optd/src/dsl/analyzer/type_checks/registry.rs index 29a4cfe9..10527bd1 100644 --- a/optd/src/dsl/analyzer/type_checks/registry.rs +++ b/optd/src/dsl/analyzer/type_checks/registry.rs @@ -3,6 +3,7 @@ use crate::dsl::analyzer::hir::{Identifier, Pattern, TypedSpan}; use crate::dsl::parser::ast::{Adt, Field}; use crate::dsl::utils::span::{OptionalSpanned, Span}; use Adt::*; +use once_cell::sync::Lazy; use std::collections::BTreeMap; use std::hash::Hasher; use std::{ @@ -22,6 +23,19 @@ pub const PHYSICAL_PROPS: &str = "PhysicalProperties"; pub const CORE_TYPES: [&str; 4] = [LOGICAL_TYPE, PHYSICAL_TYPE, LOGICAL_PROPS, PHYSICAL_PROPS]; +// Reserved ADT type instances. +pub const ARITHMERIC_TYPE: &str = "Arithmetic"; +pub const CONCAT_TYPE: &str = "Concat"; +pub const EQHASH_TYPE: &str = "EqHash"; + +pub static RESERVED_TYPE_MAP: Lazy> = Lazy::new(|| { + let mut map = HashMap::new(); + map.insert(CONCAT_TYPE.to_string(), TypeKind::Concat); + map.insert(EQHASH_TYPE.to_string(), TypeKind::EqHash); + map.insert(ARITHMERIC_TYPE.to_string(), TypeKind::Arithmetic); + map +}); + /// Represents the core structure of a type without metadata. /// /// This enum contains both primitive types (like Int64, String) and complex types @@ -38,7 +52,6 @@ pub enum TypeKind { Unit, Universe, // All types are subtypes of Universe. Nothing, // Inherits all types. - None, // Inherits all optionals. // Unknown types. UnknownAsc(usize), // Strictly ascending types. @@ -46,7 +59,7 @@ pub enum TypeKind { // User types. Adt(Identifier), - Generic(usize), // Generic types, with unique id to distinguish. + Gen(Generic), // Composite types. Array(Type), @@ -65,6 +78,10 @@ pub enum TypeKind { Arithmetic, // For types that support arithmetic operations. } +/// Represents a generic type with an optional bound. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Generic(pub usize, pub Option); + /// Represents a type, potentially with span information. /// /// This struct wraps a TypeKind with optional span information, allowing @@ -163,8 +180,8 @@ pub struct TypeRegistry { /// Maps unknown type IDs to their current inferred concrete types. /// Types start at `Nothing`, and get bumped up when needed by the constraint. pub resolved_unknown: HashMap, - /// Instantiated generic types: (call_constraint_id, generic_id) -> TypeKind. - pub instantiated_generics: HashMap<(usize, usize), TypeKind>, + /// Instantiated generic types: call_constraint_id -> Vec<(Generic, TypeKind)>. + pub instantiated_generics: HashMap>, /// Current ID to use for new Unknown / Generic / Constraint id. pub next_id: usize, } @@ -198,20 +215,27 @@ impl TypeRegistry { /// /// `Ok(())` if registration is successful, or a `AnalyzerErrorKind` if a duplicate name is found. 
pub fn register_adt(&mut self, adt: &Adt) -> Result<(), Box> { + let check_name_validity = |name, span, registry: &TypeRegistry| { + if RESERVED_TYPE_MAP.contains_key(name) { + return Err(AnalyzerErrorKind::new_reserved_type(name, span)); + } + + if let Some(existing_span) = registry.spans.get(name) { + return Err(AnalyzerErrorKind::new_duplicate_adt( + name, + existing_span, + span, + )); + } + + Ok(()) + }; + match adt { Product { name, fields } => { let type_name = name.value.as_ref().clone(); + check_name_validity(&type_name, &name.span, self)?; - // Check for duplicate ADT names. - if let Some(existing_span) = self.spans.get(&type_name) { - return Err(AnalyzerErrorKind::new_duplicate_adt( - &type_name, - existing_span, - &name.span, - )); - } - - // Register the ADT fields. self.product_fields.insert( type_name.clone(), fields.iter().map(|field| *field.value.clone()).collect(), @@ -224,22 +248,13 @@ impl TypeRegistry { } Sum { name, variants } => { let enum_name = name.value.as_ref().clone(); - - // Check for duplicate ADT names. - if let Some(existing_span) = self.spans.get(&enum_name) { - return Err(AnalyzerErrorKind::new_duplicate_adt( - &enum_name, - existing_span, - &name.span, - )); - } + check_name_validity(&enum_name, &name.span, self)?; self.spans.insert(enum_name.clone(), name.clone().span); self.subtypes.entry(enum_name.clone()).or_default(); for variant in variants { let variant_adt = variant.value.as_ref(); - // Register each variant. self.register_adt(variant_adt)?; let variant_name = match variant_adt { @@ -247,7 +262,6 @@ impl TypeRegistry { Sum { name, .. } => name.value.as_ref(), }; - // Add variant as a subtype of the enum. if let Some(children) = self.subtypes.get_mut(&enum_name) { children.insert(variant_name.clone()); } @@ -395,84 +409,6 @@ impl TypeRegistry { }); } - /// Instantiates a type by replacing all generic type references with concrete types - /// specific to the given constraint. - /// - /// When a generic type is encountered, the function looks for an existing instantiation - /// for the (constraint_id, generic_id) pair. If none exists, it creates a new ascending - /// or descending unknown type and stores it in the instantiated_generics map. - /// - /// # Arguments - /// - /// * `ty` - The type to instantiate - /// * `constraint_id` - The ID of the constraint that needs this instantiation - /// * `ascending` - Whether unknown types should be created as ascending (true) or descending (false) - /// - /// # Returns - /// - /// A new type with all generic references replaced with constraint-specific instantiations. - pub fn instantiate_type(&mut self, ty: &Type, constraint_id: usize, ascending: bool) -> Type { - use TypeKind::*; - - let span = ty.span.clone(); - - let kind = match &*ty.value { - Generic(generic_id) => { - let key = (constraint_id, *generic_id); - - let next_id = if ascending { - self.new_unknown_asc() - } else { - self.new_unknown_desc() - }; - - self.instantiated_generics - .entry(key) - .or_insert(next_id) - .clone() - } - - // For composite types, recursively instantiate their component types. 
- Array(elem_type) => Array(self.instantiate_type(elem_type, constraint_id, ascending)), - - Closure(param_type, return_type) => Closure( - self.instantiate_type(param_type, constraint_id, !ascending), - self.instantiate_type(return_type, constraint_id, ascending), - ), - - Tuple(types) => { - let instantiated_types = types - .iter() - .map(|t| self.instantiate_type(t, constraint_id, ascending)) - .collect(); - Tuple(instantiated_types) - } - - Map(key_type, value_type) => Map( - self.instantiate_type(key_type, constraint_id, !ascending), - self.instantiate_type(value_type, constraint_id, ascending), - ), - - Optional(inner_type) => { - Optional(self.instantiate_type(inner_type, constraint_id, ascending)) - } - Stored(inner_type) => { - Stored(self.instantiate_type(inner_type, constraint_id, ascending)) - } - Costed(inner_type) => { - Costed(self.instantiate_type(inner_type, constraint_id, ascending)) - } - - // Primitive types, special types, and ADT references don't need instantiation. - _ => *ty.value.clone(), - }; - - Type { - value: kind.into(), - span: span.clone(), - } - } - /// Resolves any Unknown types to their concrete types. /// /// This method checks if a type is an Unknown variant and replaces it @@ -559,4 +495,41 @@ pub mod type_registry_tests { let result = registry.register_adt(&car2); assert!(result.is_err()); } + + #[test] + fn test_reserved_type_detection() { + let mut registry = TypeRegistry::default(); + + // Try to register an ADT with a reserved name + let concat_type = create_product_adt(CONCAT_TYPE, vec![]); + let result = registry.register_adt(&concat_type); + assert!(result.is_err()); + + if let Err(err) = result { + match *err { + AnalyzerErrorKind::ReservedType { name, .. } => { + assert_eq!(name, CONCAT_TYPE); + } + _ => panic!("Expected ReservedType error"), + } + } + + // Also try with Arithmetic + let arithmetic_type = create_product_adt(ARITHMERIC_TYPE, vec![]); + let result = registry.register_adt(&arithmetic_type); + assert!(result.is_err()); + + if let Err(err) = result { + match *err { + AnalyzerErrorKind::ReservedType { name, .. } => { + assert_eq!(name, ARITHMERIC_TYPE); + } + _ => panic!("Expected ReservedType error"), + } + } + + // Normal type registration should still work + let valid_type = create_product_adt("ValidType", vec![]); + assert!(registry.register_adt(&valid_type).is_ok()); + } } diff --git a/optd/src/dsl/analyzer/type_checks/solver.rs b/optd/src/dsl/analyzer/type_checks/solver.rs index 57fb0d8a..e4749098 100644 --- a/optd/src/dsl/analyzer/type_checks/solver.rs +++ b/optd/src/dsl/analyzer/type_checks/solver.rs @@ -1,9 +1,52 @@ +//! # Type Inference and Constraint Solving +//! +//! This module implements the type inference system for our DSL. Type inference +//! is crucial for both ergonomics and correctness in the optimizer, as it allows +//! the engine to know which types are `Logical` or `Physical` (which behave +//! differently), and to verify if field accesses and function calls are valid. +//! +//! ## Type Inference Strategy +//! +//! The type inference works in three phases: +//! +//! 1. **Initial Type Creation**: During the `AST -> HIR` transformation, +//! we create and add all implicit and explicit type information from the program +//! (e.g., literals like `1` or `"hello"`, function annotations, etc.). For unknown +//! types, we generate a new ID and assign the type to either `UnknownDesc` (for +//! closure parameters and map keys) or `UnknownAsc` (for everything else). +//! +//! 2. 
**Constraint Generation**: Constraints are generated in the `generate.rs` file, +//! which also performs scope-checking. Constraints indicate subtype relationships, +//! field accesses, and function calls. For example, `let a: Logical = expr` generates +//! the constraint `Logical :> typeof(expr)`. +//! +//! 3. **Constraint Solving**: The final step uses a constraint solver that iteratively +//! refines unknown types until reaching a fixed point where no more refinements +//! can be made. +//! +//! ## Constraint Solving Algorithm +//! +//! The constraint solving algorithm makes monotonic progress by tracking whether any +//! types changed during each iteration and continuing until reaching a fixed point. +//! Unknown types are refined according to their variance: +//! +//! - `UnknownAsc`: These types start at `Nothing` and ascend up the type hierarchy +//! as needed. When encountered as a parent, they are updated to the least upper +//! bound (LUB) of themselves and the child type. +//! +//! - `UnknownDesc`: These types start at `Universe` and descend down the type +//! hierarchy as needed. When encountered as a child, they are updated to the +//! greatest lower bound (GLB) of themselves and the parent type. +//! +//! The solver continues until no more changes can be made, at which point it either +//! reports success or returns the most relevant type error. + use super::registry::{Constraint, Type, TypeRegistry}; use crate::dsl::{ analyzer::{ errors::AnalyzerErrorKind, hir::{Identifier, Pattern, PatternKind, TypedSpan}, - type_checks::registry::TypeKind, + type_checks::registry::{Generic, TypeKind}, }, utils::span::Span, }; @@ -12,9 +55,19 @@ use std::mem; impl TypeRegistry { /// Resolves all collected constraints and fills in the concrete types. /// - /// This method iterates through all constraints, checking subtype relationships - /// and refining unknown types until either all constraints are satisfied or - /// a constraint cannot be satisfied and no more progress can be made. + /// This method implements the main constraint solving algorithm, iterating through + /// all constraints and refining unknown types until either all constraints are + /// satisfied or we reach a fixed point where no more progress can be made. + /// + /// # Algorithm + /// + /// 1. The algorithm iteratively processes all constraints, tracking whether any + /// type refinements were made in each iteration. + /// 2. For each constraint, we attempt to satisfy it by refining unknown types. + /// 3. We continue until reaching a fixed point (no more changes) or until all + /// constraints are satisfied. + /// 4. If constraints remain unsatisfied at the fixed point, we return the most + /// relevant type error. /// /// # Returns /// @@ -56,7 +109,20 @@ impl TypeRegistry { } } - /// Checks if a single constraint is satisfied. + /// Checks if a single constraint is satisfied, potentially refining unknown types. + /// + /// This method dispatches to type-specific constraint checkers based on the + /// constraint kind (subtype, call, scrutinee, or field access). + /// + /// # Arguments + /// + /// * `constraint` - The constraint to check + /// * `changed` - Mutable flag that is set to true if any types were refined + /// + /// # Returns + /// + /// * `Ok(())` if the constraint is satisfied or could be satisfied through refinement. + /// * `Err(error)` if the constraint cannot be satisfied. 
fn check_constraint( &mut self, constraint: &Constraint, @@ -86,6 +152,21 @@ impl TypeRegistry { } } + /// Checks if a subtyping constraint is satisfied, refining unknown types if needed. + /// + /// This method verifies that `child` is a subtype of `parent_ty`, potentially + /// refining unknown types to satisfy this relationship. + /// + /// # Arguments + /// + /// * `child` - The child type with its source location + /// * `parent_ty` - The parent type + /// * `changed` - Mutable flag that is set to true if any types were refined + /// + /// # Returns + /// + /// * `Ok(())` if the subtyping relationship is satisfied or could be satisfied. + /// * `Err(error)` if the subtyping relationship cannot be satisfied. fn check_subtype_constraint( &mut self, child: &TypedSpan, @@ -104,6 +185,136 @@ impl TypeRegistry { }) } + /// Instantiates a type by replacing all generic type references with concrete types + /// specific to the given constraint. + /// + /// This method handles generics by creating unique type instantiations for each + /// constraint context. When encountering a generic type, it either retrieves an + /// existing instantiation or creates a new one based on the variance direction. + /// If the generic has a bound, it ensures the instantiated type satisfies that bound. + /// + /// # Arguments + /// + /// * `ty` - The type to instantiate + /// * `constraint_id` - The ID of the constraint that needs this instantiation + /// * `span` - The span where the expression occurs for error reporting + /// * `ascending` - Whether unknown types should be created as ascending (true) or descending (false) + /// + /// # Returns + /// + /// A Result containing either the instantiated type or an error if bound constraints cannot be satisfied. + fn instantiate_type( + &mut self, + ty: &Type, + constraint_id: usize, + span: &Span, + ascending: bool, + ) -> Result> { + use TypeKind::*; + + let kind = match &*ty.value { + Gen(generic @ Generic(_, bound)) => { + // Try to find an existing instantiation for this generic in this constraint. + let existing = self + .instantiated_generics + .get(&constraint_id) + .and_then(|vec| { + vec.iter() + .find(|(g, _)| g == generic) + .map(|(_, ty)| ty.clone()) + }); + + // If found, return the existing instantiation. + if let Some(existing_type) = existing { + existing_type + } else { + // Otherwise, create a new instantiation. + let next_id = if ascending { + self.new_unknown_asc() + } else { + self.new_unknown_desc() + }; + + // Get or create the vector for this constraint_id. + let entry = self.instantiated_generics.entry(constraint_id).or_default(); + entry.push((generic.clone(), next_id.clone())); + + // Check the bound constraint. + if let Some(bound) = bound { + let next_id_ty: Type = next_id.clone().into(); + self.check_subtype_constraint( + &TypedSpan::new(next_id_ty, span.clone()), + bound, + &mut false, // The unknown type appears no-where else as it has been introduced here. + ) + .unwrap(); + } + + next_id + } + } + + // For composite types, recursively instantiate their component types. + Array(elem_type) => { + Array(self.instantiate_type(elem_type, constraint_id, span, ascending)?) 
+ } + + Closure(param_type, return_type) => Closure( + self.instantiate_type(param_type, constraint_id, span, !ascending)?, + self.instantiate_type(return_type, constraint_id, span, ascending)?, + ), + + Tuple(types) => { + let instantiated_types = types + .iter() + .map(|t| self.instantiate_type(t, constraint_id, span, ascending)) + .collect::<Result<Vec<_>, _>>()?; + Tuple(instantiated_types) + } + + Map(key_type, value_type) => Map( + self.instantiate_type(key_type, constraint_id, span, !ascending)?, + self.instantiate_type(value_type, constraint_id, span, ascending)?, + ), + + Optional(inner_type) => { + Optional(self.instantiate_type(inner_type, constraint_id, span, ascending)?) + } + Stored(inner_type) => { + Stored(self.instantiate_type(inner_type, constraint_id, span, ascending)?) + } + Costed(inner_type) => { + Costed(self.instantiate_type(inner_type, constraint_id, span, ascending)?) + } + + // Primitive types, special types, and ADT references don't need instantiation. + _ => *ty.value.clone(), + }; + + Ok(Type { + value: kind.into(), + span: ty.span.clone(), + }) + } + + /// Checks a function call constraint, ensuring parameter and return types match. + /// + /// This method handles function calls, array indexing, and map lookups by verifying + /// that the arguments match the expected parameter types and that the return type + /// is compatible with the expected output type. + /// + /// # Arguments + /// + /// * `constraint_id` - The ID of the constraint for instantiating generic functions + /// * `inner` - The function/container being called/indexed + /// * `args` - The arguments/indices provided to the call + /// * `outer` - The expected return type + /// * `changed` - Mutable flag that is set to true if any types were refined + /// + /// # Returns + /// + /// * `Ok(())` if the call constraint is satisfied or could be satisfied. + /// * `Err(error)` if the call constraint cannot be satisfied. fn check_call_constraint( &mut self, constraint_id: usize, @@ -118,14 +329,27 @@ impl TypeRegistry { TypeKind::Nothing => Ok(()), TypeKind::Closure(param, ret) => { - // Initiatialize potential generics. - let param = self.instantiate_type(param, constraint_id, true); - let ret = self.instantiate_type(ret, constraint_id, true); - - let (param_len, param_types) = match &*param { - TypeKind::Tuple(types) => (types.len(), types.to_vec()), + // Instantiate potential generics. + let ret = self.instantiate_type(ret, constraint_id, &inner.span, true)?; + + let (param_len, param_types) = match &*param.value { + TypeKind::Tuple(types) => { + let instantiated_types = types + .iter() + .zip(args.iter()) + .map(|(param_type, arg)| { + self.instantiate_type(param_type, constraint_id, &arg.span, true) + }) + .collect::<Result<Vec<_>, _>>()?; + + (types.len(), instantiated_types) + } TypeKind::Unit => (0, vec![]), - _ => (1, vec![param.clone()]), + _ => { + let param_inst = + self.instantiate_type(param, constraint_id, &args[0].span, true)?; + (1, vec![param_inst]) + } }; + if param_len != args.len() { @@ -173,6 +397,26 @@ impl TypeRegistry { } } + /// Helper method for checking indexing operations on arrays and maps. + /// + /// This method verifies that an indexing operation is valid by checking that: + /// 1. Exactly one argument/index is provided + /// 2. The index type matches the expected key type + /// 3.
The resulting element type (wrapped in Optional) is compatible with the expected output + /// + /// # Arguments + /// + /// * `span` - The source location of the indexing operation + /// * `args` - The arguments/indices provided (should be exactly one) + /// * `key_type` - The expected type of the index + /// * `elem_type` - The type of elements in the container + /// * `outer_ty` - The expected output type + /// * `changed` - Mutable flag that is set to true if any types were refined + /// + /// # Returns + /// + /// * `Ok(())` if the indexing constraint is satisfied or could be satisfied. + /// * `Err(error)` if the indexing constraint cannot be satisfied. fn check_indexable( &mut self, span: &Span, @@ -202,6 +446,23 @@ impl TypeRegistry { ) } + /// Checks a field access constraint, ensuring the field exists and has the correct type. + /// + /// This method verifies field access on ADTs and tuples by checking that: + /// 1. For ADTs: the field exists and its type is compatible with the expected output + /// 2. For tuples: the _N syntax is used with a valid index, and the type matches + /// + /// # Arguments + /// + /// * `inner` - The object whose field is being accessed + /// * `field` - The name of the field being accessed + /// * `outer` - The expected type of the field + /// * `changed` - Mutable flag that is set to true if any types were refined + /// + /// # Returns + /// + /// * `Ok(())` if the field access constraint is satisfied or could be satisfied. + /// * `Err(error)` if the field access constraint cannot be satisfied. fn check_field_access_constraint( &mut self, inner: &TypedSpan, @@ -259,6 +520,22 @@ impl TypeRegistry { } } + /// Checks a pattern matching constraint, ensuring the scrutinee matches the pattern. + /// + /// This method handles various pattern matching constructs including binding, + /// struct destructuring, array decomposition, and wildcards by ensuring type + /// compatibility between the scrutinee and pattern. + /// + /// # Arguments + /// + /// * `scrutinee` - The expression being matched + /// * `pattern` - The pattern to match against + /// * `changed` - Mutable flag that is set to true if any types were refined + /// + /// # Returns + /// + /// * `Ok(())` if the pattern matching constraint is satisfied or could be satisfied. + /// * `Err(error)` if the pattern matching constraint cannot be satisfied. fn check_scrutinee_constraint( &mut self, scrutinee: &TypedSpan, @@ -322,7 +599,7 @@ impl TypeRegistry { // First check scrutinee against wildcard type. self.check_subtype_constraint(scrutinee, &pattern.metadata.ty, changed)?; - // Then check wildcard pattern against scrutinee type + // Then check wildcard pattern against scrutinee type. 
self.check_subtype_constraint(&pattern.metadata, &scrutinee_ty, changed) } } @@ -623,7 +900,7 @@ mod tests { | BoolLiteral(value: Bool) \ StringLiteral(value: String) | BinaryOp = - | Arithmetic = + | Arith = | Add(left: Expression, right: Expression) | Subtract(left: Expression, right: Expression) | Multiply(left: Expression, right: Expression) @@ -1039,4 +1316,143 @@ mod tests { "Generic type constraints should have failed type inference" ); } + + #[test] + fn test_generic_with_supertype_bound() { + let source = r#" + // Small ADT type hierarchy + data Animal = + | Mammal = + | Dog(name: String) + \ Cat(name: String) + \ Bird(name: String) + + // Generic function with a bound requiring type to be an Animal + fn describe(entity: E): String = + match entity + | Dog(name) -> name ++ " is a dog" + | Cat(name) -> name ++ " is a cat" + \ Bird(name) -> name ++ " is a bird" + + + fn main(): String = + let + // Create specific animal types + dog = Dog("Rex"), + cat = Cat("Whiskers"), + bird = Bird("Tweety"), + in + describe(dog) + + "#; + + let result = run_type_inference(source, "generic_with_supertype_bound.opt"); + assert!( + result.is_ok(), + "Generic function with supertype bound failed type inference: {:?}", + result + ); + } + + #[test] + fn test_pairs_to_map() { + let source = r#" + // Convert array of key-value pairs to a map + fn (pairs: [(K, V)]) to_map(): {K: V} = match pairs + | [head .. tail] -> {head#_0: head#_1} ++ tail.to_map() + \ [] -> {} + + fn main(): {String : I64} = + let + // Create pairs and convert to map + pairs = [("a", 1), ("b", 2), ("c", 3)], + map = pairs.to_map(), + in + map + "#; + + let result = run_type_inference(source, "simple_to_map.opt"); + assert!( + result.is_ok(), + "Basic pairs to map conversion failed: {:?}", + result + ); + } + + #[test] + fn test_pairs_to_map_without_eqhash() { + let source = r#" + // Convert array of key-value pairs to a map, but missing EqHash bound + fn (pairs: [(K, V)]) to_map(): {K: V} = match pairs + | [head .. tail] -> {head#_0: head#_1} ++ tail.to_map() + \ [] -> {} + + fn main(): {String: I64} = + let + // Create pairs and convert to map + pairs = [("a", 1), ("b", 2), ("c", 3)], + map = pairs.to_map() + in + map + "#; + + let result = run_type_inference(source, "missing_eqhash.opt"); + assert!( + result.is_err(), + "Map creation without EqHash bound should have failed" + ); + } + + #[test] + fn test_with_exact_source() { + let source = r#" + // Convert array of key-value pairs to a map with EqHash constraint + fn (pairs: [(K, V)]) to_map(): {K: V} = match pairs + | [head .. tail] -> {head#_0: head#_1} ++ tail.to_map() + \ [] -> {} + + // Simple filter function for arrays + fn (array: [T]) filter(pred: T -> Bool): [T] = match array + | [] -> [] + \ [x .. xs] -> + if pred(x) then [x] ++ xs.filter(pred) + else xs.filter(pred) + + // Simple map function + fn (array: [T]) map(f: T -> U): [U] = match array + | [] -> [] + \ [x .. 
xs] -> [f(x)] ++ xs.map(f) + + // Data type for our test + data User(name: String, age: I64) + + fn main(): {String: I64} = + let + // Create some users + users = [ + User("Alice", 25), + User("Bob", 17), + User("Charlie", 30), + User("Diana", 15) + ], + + // Filter for adults + adults = users.filter((u: User) -> u#age >= 18), + + // Create name -> age map + name_age_pairs = adults.map((u: User) -> (u#name, u#age)), + + // Convert to map + age_map = name_age_pairs.to_map() + in + age_map + "#; + + let result = run_type_inference(source, "exact_source_test.opt"); + assert!( + result.is_ok(), + "Type inference with exact source code failed: {:?}", + result + ); + } } diff --git a/optd/src/dsl/analyzer/type_checks/subtype.rs b/optd/src/dsl/analyzer/type_checks/subtype.rs index aea6e88c..a629b89e 100644 --- a/optd/src/dsl/analyzer/type_checks/subtype.rs +++ b/optd/src/dsl/analyzer/type_checks/subtype.rs @@ -1,5 +1,5 @@ use super::registry::{LOGICAL_TYPE, PHYSICAL_TYPE, Type, TypeRegistry}; -use crate::dsl::analyzer::type_checks::registry::TypeKind; +use crate::dsl::analyzer::type_checks::registry::{Generic, TypeKind}; use std::collections::HashSet; impl TypeRegistry { @@ -131,9 +131,26 @@ impl TypeRegistry { self.is_subtype_inner(child, &self.resolve_type(parent), memo, has_changed) } - // Generics only match if they have strictly the same name. - // Bounded generics are not yet supported. - (Generic(gen1), Generic(gen2)) if gen1 == gen2 => true, + // Generics only match if they have strictly the same name, + // or if their bounds are compatible. + (Gen(Generic(id1, bound1)), Gen(Generic(id2, bound2))) => { + if id1 == id2 { + true + } else if let (Some(b1), Some(b2)) = (bound1, bound2) { + self.is_subtype_inner(b1, b2, memo, has_changed) + } else { + false + } + } + + // A type is a subtype of a bounded generic if it's a subtype of the bound, + // and vice versa. + (_, Gen(Generic(_, Some(bound)))) => { + self.is_subtype_inner(child, bound, memo, has_changed) + } + (Gen(Generic(_, Some(bound))), _) => { + self.is_subtype_inner(bound, parent, memo, has_changed) + } // Stored and Costed type handling. (Stored(child_inner), Stored(parent_inner)) => { @@ -208,8 +225,6 @@ impl TypeRegistry { (Optional(child_ty), Optional(parent_ty)) => { self.is_subtype_inner(child_ty, parent_ty, memo, has_changed) } - // None <: Optional[Nothing]. - (None, Optional(_)) => true, // Likewise, T <: Optional[T]. 
(_, Optional(parent_inner)) => { self.is_subtype_inner(child, parent_inner, memo, has_changed) @@ -227,7 +242,6 @@ impl TypeRegistry { (String, EqHash) => true, (Bool, EqHash) => true, (Unit, EqHash) => true, - (None, EqHash) => true, (Optional(inner), EqHash) => { self.is_subtype_inner(inner, &EqHash.into(), memo, has_changed) } @@ -723,64 +737,101 @@ mod tests { } #[test] - fn test_generic_subtyping() { + fn test_generic_subtyping_with_bounds() { let mut reg = TypeRegistry::default(); - // Generics are only subtypes of themselves (same name) - assert!(reg.is_subtype(&Generic(0).into(), &Generic(0).into())); + // Set up a type hierarchy for testing + let animal = create_product_adt("Animal", vec![]); + let dog = create_product_adt("Dog", vec![]); + let cat = create_product_adt("Cat", vec![]); + + let animals_enum = create_sum_adt("Animals", vec![animal, dog.clone(), cat.clone()]); + reg.register_adt(&animals_enum).unwrap(); - // Different named generics are not subtypes - assert!(!reg.is_subtype(&Generic(0).into(), &Generic(1).into())); + // Create generic types with and without bounds - using unique IDs + let generic_1 = Gen(Generic(1, Option::None)).into(); + let generic_2 = Gen(Generic(2, Option::None)).into(); - // All generics are subtypes of Universe - assert!(reg.is_subtype(&Generic(0).into(), &Universe.into())); + // Create bounded generics with *unique* IDs - bounds must be ADTs only + let generic_3_animals = Gen(Generic(3, Some(Adt("Animals".to_string()).into()))).into(); + let generic_4_animals = Gen(Generic(4, Some(Adt("Animals".to_string()).into()))).into(); + let generic_5_dog = Gen(Generic(5, Some(Adt("Dog".to_string()).into()))).into(); - // Nothing is a subtype of any generic - assert!(reg.is_subtype(&Nothing.into(), &Generic(0).into())); + // Test 1: Same generic ID is a subtype of itself + assert!(reg.is_subtype(&generic_1, &generic_1)); + assert!(reg.is_subtype(&generic_3_animals, &generic_3_animals)); - // Generic is not a subtype of concrete types - assert!(!reg.is_subtype(&Generic(0).into(), &I64.into())); + // Test 2: Different generic IDs with compatible bounds + // Gen<5: Dog> <: Gen<4: Animals> because Dog <: Animals + assert!(reg.is_subtype(&generic_5_dog, &generic_4_animals)); - // Concrete types are not subtypes of generics - assert!(!reg.is_subtype(&I64.into(), &Generic(0).into())); + // Test 3: Different generic IDs with incompatible bounds + // Gen<4: Animals> !<: Gen<5: Dog> because Animals !<: Dog + assert!(!reg.is_subtype(&generic_4_animals, &generic_5_dog)); - // Test with generic in container types - assert!(reg.is_subtype( - &Array(Generic(0).into()).into(), - &Array(Generic(0).into()).into() - )); + // Test 4: Different generic IDs with no bounds are never subtypes + assert!(!reg.is_subtype(&generic_1, &generic_2)); - // Different generics in container types - assert!(!reg.is_subtype( - &Array(Generic(0).into()).into(), - &Array(Generic(1).into()).into() - )); + // Test 5: Concrete type vs generic with bound + let dog_type = Adt("Dog".to_string()).into(); + let animals_type = Adt("Animals".to_string()).into(); + + // Dog <: Gen<4: Animals> because Dog <: Animals + assert!(reg.is_subtype(&dog_type, &generic_4_animals)); + + // Animals !<: Gen<5: Dog> because Animals !<: Dog + assert!(!reg.is_subtype(&animals_type, &generic_5_dog)); + + // Test 6: Generic with bound vs concrete type (bidirectional check) + // Gen<5: Dog> <: Animals because Dog <: Animals + assert!(reg.is_subtype(&generic_5_dog, &animals_type)); + + // Gen<4: Animals> !<: Dog because 
Animals !<: Dog + assert!(!reg.is_subtype(&generic_4_animals, &dog_type)); } #[test] - fn test_none_subtyping() { + fn test_container_generic_subtyping() { let mut reg = TypeRegistry::default(); - // Test None as a subtype of any Optional type - assert!(reg.is_subtype(&None.into(), &Optional(I64.into()).into())); - assert!(reg.is_subtype(&None.into(), &Optional(String.into()).into())); - assert!(reg.is_subtype(&None.into(), &Optional(Bool.into()).into())); - assert!(reg.is_subtype(&None.into(), &Optional(F64.into()).into())); - assert!(reg.is_subtype(&None.into(), &Optional(Unit.into()).into())); + // Set up a type hierarchy for testing + let animal = create_product_adt("Animal", vec![]); + let dog = create_product_adt("Dog", vec![]); + let cat = create_product_adt("Cat", vec![]); + + let animals_enum = create_sum_adt("Animals", vec![animal, dog.clone(), cat.clone()]); + reg.register_adt(&animals_enum).unwrap(); - // Test None with complex Optional types - assert!(reg.is_subtype(&None.into(), &Optional(Array(I64.into()).into()).into())); + // Create generic types with bounds + let generic_animals: Type = Gen(Generic(1, Some(Adt("Animals".to_string()).into()))).into(); + let generic_dog: Type = Gen(Generic(2, Some(Adt("Dog".to_string()).into()))).into(); - // Test that None is not a subtype of non-Optional types - assert!(!reg.is_subtype(&None.into(), &I64.into())); - assert!(!reg.is_subtype(&None.into(), &String.into())); + // Test container types with generics - // None is still a subtype of Universe (as all types are) - assert!(reg.is_subtype(&None.into(), &Universe.into())); + // Array> <: Array> because Gen <: Gen + assert!(reg.is_subtype( + &Array(generic_dog.clone()).into(), + &Array(generic_animals.clone()).into() + )); - // None is not equal to Nothing - assert!(!reg.is_subtype(&None.into(), &Nothing.into())); - assert!(reg.is_subtype(&Nothing.into(), &None.into())); + // Map> <: Map> because Gen <: Gen + assert!(reg.is_subtype( + &Map(String.into(), generic_dog.clone()).into(), + &Map(String.into(), generic_animals.clone()).into() + )); + + // Map, String> <: Map, String> because Gen <: Gen + // (contravariance for Map keys) + assert!(reg.is_subtype( + &Map(generic_animals.clone(), String.into()).into(), + &Map(generic_dog.clone(), String.into()).into() + )); + + // Function with generics: (Gen -> String) <: (Gen -> String) + assert!(reg.is_subtype( + &Closure(generic_animals.clone(), String.into()).into(), + &Closure(generic_dog.clone(), String.into()).into() + )); } #[test] @@ -928,7 +979,6 @@ mod tests { assert!(!reg.is_subtype(&Bool.into(), &Concat.into())); assert!(!reg.is_subtype(&F64.into(), &Concat.into())); assert!(!reg.is_subtype(&Unit.into(), &Concat.into())); - assert!(!reg.is_subtype(&None.into(), &Concat.into())); assert!(!reg.is_subtype( &Tuple(vec![I64.into(), String.into()]).into(), &Concat.into() @@ -955,7 +1005,6 @@ mod tests { assert!(reg.is_subtype(&String.into(), &EqHash.into())); assert!(reg.is_subtype(&Bool.into(), &EqHash.into())); assert!(reg.is_subtype(&Unit.into(), &EqHash.into())); - assert!(reg.is_subtype(&None.into(), &EqHash.into())); // Test tuple types with all EqHash elements assert!(reg.is_subtype( @@ -995,7 +1044,6 @@ mod tests { assert!(!reg.is_subtype(&String.into(), &Arithmetic.into())); assert!(!reg.is_subtype(&Bool.into(), &Arithmetic.into())); assert!(!reg.is_subtype(&Unit.into(), &Arithmetic.into())); - assert!(!reg.is_subtype(&None.into(), &Arithmetic.into())); assert!(!reg.is_subtype( &Tuple(vec![I64.into(), F64.into()]).into(), 
&Arithmetic.into() diff --git a/optd/src/dsl/parser/ast.rs b/optd/src/dsl/parser/ast.rs index 6c57a7aa..74b94d04 100644 --- a/optd/src/dsl/parser/ast.rs +++ b/optd/src/dsl/parser/ast.rs @@ -239,7 +239,7 @@ pub struct Function { /// Name of the function with source location pub name: Spanned, /// Optional generic type parameters - pub type_params: Vec>, + pub type_params: Vec<(Spanned, Option>)>, /// Optional receiver for method-style functions (self parameter) pub receiver: Option>, /// Optional parameters list diff --git a/optd/src/dsl/parser/function.rs b/optd/src/dsl/parser/function.rs index 58151a02..7635a8df 100644 --- a/optd/src/dsl/parser/function.rs +++ b/optd/src/dsl/parser/function.rs @@ -24,12 +24,14 @@ use chumsky::{ /// /// 1. Regular functions with implementation: /// ```ignore -/// [annotations] fn (receiver): name(params): ReturnType = body +/// [annotations] +/// fn (receiver): name(params): ReturnType = body /// ``` /// /// 2. Extern functions (declarations without implementation): /// ```ignore -/// [annotations] fn (receiver): name(params): ReturnType +/// [annotations] +/// fn (receiver): name(params): ReturnType /// ``` /// /// # Components @@ -38,10 +40,11 @@ use chumsky::{ /// - Used to mark special properties or behaviors /// - Example: `[rust]` might indicate the function is implemented in Rust /// -/// * ``: Optional generic type parameters +/// * ``: Optional generic type parameters with optional bounds /// - Enclosed in angle brackets /// - Used for generic functions /// - Example: `` for type parameters T and U +/// - Example with bounds: `` for constrained type parameters /// /// * `(receiver)`: Optional receiver parameter for method-style functions /// - Similar to `self` in Rust or `this` in other languages @@ -75,9 +78,14 @@ use chumsky::{ /// fn identity(x: T): T = x /// ``` /// -/// Extern function with generic parameters: +/// Generic function with bound: /// ```ignore -/// [rust] fn native_map_get(map: {K: V}, key: K): V? +/// fn max(a: T, b: T): T = if a > b then a else b +/// ``` +/// +/// Extern function with generic parameters and bounds: +/// ```ignore +/// [rust] fn native_map_get(map: {K: V}, key: K): V? /// ``` /// /// # Error Recovery @@ -109,9 +117,13 @@ pub fn function_parser() let ident_parser = select! { Token::TermIdent(name) => name }.map_with_span(Spanned::new); let type_ident_parser = select! 
{ Token::TypeIdent(name) => name }.map_with_span(Spanned::new); - // Parse optional generic type parameters like + // Parse type parameter with optional bound + let type_param_parser = + type_ident_parser.then(just(Token::Colon).ignore_then(type_parser()).or_not()); + + // Parse optional generic type parameters like let type_params = delimited_parser( - type_ident_parser + type_param_parser .separated_by(just(Token::Comma)) .allow_trailing(), Token::Less, @@ -264,7 +276,9 @@ mod tests { // Check type parameters assert_eq!(func.value.type_params.len(), 1); - assert_eq!(*func.value.type_params[0].value, "T"); + let (type_param, bound) = &func.value.type_params[0]; + assert_eq!(*type_param.value, "T"); + assert!(bound.is_none()); // Check parameter assert!(func.value.params.is_some()); @@ -283,6 +297,47 @@ mod tests { } } + #[test] + fn test_generic_function_with_bounds() { + let input = "fn max(a: T, b: T): T = if a > b then a else b"; + let (result, errors) = parse_function(input); + + assert!(result.is_some(), "Expected successful parse"); + assert!(errors.is_empty(), "Expected no errors"); + + if let Some(func) = result { + assert_eq!(*func.value.name.value, "max"); + + // Check type parameters + assert_eq!(func.value.type_params.len(), 1); + let (type_param, bound) = &func.value.type_params[0]; + assert_eq!(*type_param.value, "T"); + assert!(bound.is_some()); + if let Some(bound_val) = bound { + assert_eq!(*bound_val.value, Type::Identifier("Comparable".to_string())); + } + + // Check parameters + assert!(func.value.params.is_some()); + let params = func.value.params.as_ref().unwrap(); + assert_eq!(params.len(), 2); + assert_eq!(*params[0].value.name.value, "a"); + assert!( + matches!(*params[0].clone().value.ty.value, Type::Identifier(name) if name == "T") + ); + assert_eq!(*params[1].value.name.value, "b"); + assert!( + matches!(*params[1].clone().value.ty.value, Type::Identifier(name) if name == "T") + ); + + // Check return type + assert!(matches!(*func.value.return_type.value, Type::Identifier(name) if name == "T")); + + // Check body + assert!(func.value.body.is_some()); + } + } + #[test] fn test_generic_function_multiple_type_params() { let input = "fn mapGet(map: {K: V}, key: K): V? = map.get(key)"; @@ -296,8 +351,68 @@ mod tests { // Check type parameters assert_eq!(func.value.type_params.len(), 2); - assert_eq!(*func.value.type_params[0].value, "K"); - assert_eq!(*func.value.type_params[1].value, "V"); + let (k_param, k_bound) = &func.value.type_params[0]; + let (v_param, v_bound) = &func.value.type_params[1]; + assert_eq!(*k_param.value, "K"); + assert_eq!(*v_param.value, "V"); + assert!(k_bound.is_none()); + assert!(v_bound.is_none()); + + // Check parameters + assert!(func.value.params.is_some()); + let params = func.value.params.as_ref().unwrap(); + assert_eq!(params.len(), 2); + + // Check first parameter (map: {K: V}) + assert_eq!(*params[0].value.name.value, "map"); + if let Type::Map(key_ty, val_ty) = &*params[0].value.ty.value { + assert!(matches!(*key_ty.clone().value, Type::Identifier(name) if name == "K")); + assert!(matches!(*val_ty.clone().value, Type::Identifier(name) if name == "V")); + } else { + panic!("Expected Map type for first parameter"); + } + + // Check second parameter (key: K) + assert_eq!(*params[1].value.name.value, "key"); + assert!( + matches!(*params[1].clone().value.ty.value, Type::Identifier(name) if name == "K") + ); + + // Check return type (V?) 
+ if let Type::Questioned(inner) = &*func.value.return_type.value { + assert!(matches!(*inner.clone().value, Type::Identifier(name) if name == "V")); + } else { + panic!("Expected Optional return type"); + } + } + } + + #[test] + fn test_generic_function_multiple_type_params_with_bounds() { + let input = "fn mapGet(map: {K: V}, key: K): V? = map.get(key)"; + let (result, errors) = parse_function(input); + + assert!(result.is_some(), "Expected successful parse"); + assert!(errors.is_empty(), "Expected no errors"); + + if let Some(func) = result { + assert_eq!(*func.value.name.value, "mapGet"); + + // Check type parameters + assert_eq!(func.value.type_params.len(), 2); + + // Check K parameter with Hashable bound + let (k_param, k_bound) = &func.value.type_params[0]; + assert_eq!(*k_param.value, "K"); + assert!(k_bound.is_some()); + if let Some(bound) = k_bound { + assert_eq!(*bound.value, Type::Identifier("Hashable".to_string())); + } + + // Check V parameter with no bound + let (v_param, v_bound) = &func.value.type_params[1]; + assert_eq!(*v_param.value, "V"); + assert!(v_bound.is_none()); // Check parameters assert!(func.value.params.is_some()); @@ -341,9 +456,68 @@ mod tests { // Check type parameters assert_eq!(func.value.type_params.len(), 3); - assert_eq!(*func.value.type_params[0].value, "A"); - assert_eq!(*func.value.type_params[1].value, "B"); - assert_eq!(*func.value.type_params[2].value, "C"); + assert_eq!(*func.value.type_params[0].0.value, "A"); + assert_eq!(*func.value.type_params[1].0.value, "B"); + assert_eq!(*func.value.type_params[2].0.value, "C"); + assert!(func.value.type_params[0].1.is_none()); + assert!(func.value.type_params[1].1.is_none()); + assert!(func.value.type_params[2].1.is_none()); + + // Check parameters + assert!(func.value.params.is_some()); + let params = func.value.params.as_ref().unwrap(); + assert_eq!(params.len(), 2); + assert_eq!(*params[0].value.name.value, "a"); + assert!( + matches!(*params[0].clone().value.ty.value, Type::Identifier(name) if name == "A") + ); + assert_eq!(*params[1].value.name.value, "b"); + assert!( + matches!(*params[1].clone().value.ty.value, Type::Identifier(name) if name == "B") + ); + + // Check return type + assert!(matches!(*func.value.return_type.value, Type::Identifier(name) if name == "C")); + + // Check body is None + assert!(func.value.body.is_none()); + } + } + + #[test] + fn test_generic_extern_function_with_bounds() { + let input = "fn externalFunc(a: A, b: B): C"; + let (result, errors) = parse_function(input); + + assert!(result.is_some(), "Expected successful parse"); + assert!(errors.is_empty(), "Expected no errors"); + + if let Some(func) = result { + assert_eq!(*func.value.name.value, "externalFunc"); + + // Check type parameters with bounds + assert_eq!(func.value.type_params.len(), 3); + + // Check A parameter with Serializable bound + let (a_param, a_bound) = &func.value.type_params[0]; + assert_eq!(*a_param.value, "A"); + assert!(a_bound.is_some()); + if let Some(bound) = a_bound { + assert_eq!(*bound.value, Type::Identifier("Serializable".to_string())); + } + + // Check B parameter with Printable bound + let (b_param, b_bound) = &func.value.type_params[1]; + assert_eq!(*b_param.value, "B"); + assert!(b_bound.is_some()); + if let Some(bound) = b_bound { + assert_eq!(*bound.value, Type::Identifier("Printable".to_string())); + } + + // Check C parameter with no bound + let (c_param, c_bound) = &func.value.type_params[2]; + assert_eq!(*c_param.value, "C"); + assert!(c_bound.is_none()); // Check parameters 
assert!(func.value.params.is_some()); @@ -708,4 +882,52 @@ mod tests { assert!(func.value.body.is_none()); // Extern function has no body } } + + #[test] + fn test_mixed_bounds_and_no_bounds() { + let input = "fn process(a: A, b: B, c: C): B = b"; + let (result, errors) = parse_function(input); + + assert!(result.is_some(), "Expected successful parse"); + assert!(errors.is_empty(), "Expected no errors"); + + if let Some(func) = result { + assert_eq!(*func.value.name.value, "process"); + + // Check type parameters with mixed bounds + assert_eq!(func.value.type_params.len(), 3); + + // Check A parameter with Serializable bound + let (a_param, a_bound) = &func.value.type_params[0]; + assert_eq!(*a_param.value, "A"); + assert!(a_bound.is_some()); + if let Some(bound) = a_bound { + assert_eq!(*bound.value, Type::Identifier("Serializable".to_string())); + } + + // Check B parameter with no bound + let (b_param, b_bound) = &func.value.type_params[1]; + assert_eq!(*b_param.value, "B"); + assert!(b_bound.is_none()); + + // Check C parameter with Printable bound + let (c_param, c_bound) = &func.value.type_params[2]; + assert_eq!(*c_param.value, "C"); + assert!(c_bound.is_some()); + if let Some(bound) = c_bound { + assert_eq!(*bound.value, Type::Identifier("Printable".to_string())); + } + + // Check parameters + assert!(func.value.params.is_some()); + let params = func.value.params.as_ref().unwrap(); + assert_eq!(params.len(), 3); + + // Check return type + assert!(matches!(*func.value.return_type.value, Type::Identifier(name) if name == "B")); + + // Check body is present for regular function + assert!(func.value.body.is_some()); + } + } } From e7d28f09d970c99e805698e9ac0c9ff9a4727922 Mon Sep 17 00:00:00 2001 From: AlSchlo <79570602+AlSchlo@users.noreply.github.com> Date: Sat, 3 May 2025 22:43:43 +0200 Subject: [PATCH 7/8] (feat) Execute DSL (#99) ## Overview DSL functions can now be executed through the CLI with the `[run]` annotation. Try it out! `cargo run --bin optd-cli -- run-functions [path]/logical_rules.opt` ## Known Issues - Stack overflows *very* quickly due to the lack of TCO in the engine. Needs to be optimized. - Existing type inference issues.
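## Example

As a usage sketch: a `[run]` function takes no arguments, and its result (or failure) is printed by the CLI. The bundled `optd-cli/examples/logical_rules.opt` added in this patch contains a set of such functions, for example (quoted verbatim from the file below):

```
[run]
fn run_add_commute() =
    let
        const7 = Const(7),
        const12 = Const(12),
        addition = Add(const7, const12)
    in
    add_commute(addition)
```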
--- Cargo.lock | 1 + optd-cli/Cargo.toml | 1 + optd-cli/examples/logical_rules.opt | 281 +++++++++++++++++++++++++++ optd-cli/src/main.rs | 137 ++++++++++++- optd/src/dsl/analyzer/hir/display.rs | 221 +++++++++++++++++++++ optd/src/dsl/analyzer/hir/mod.rs | 1 + 6 files changed, 637 insertions(+), 5 deletions(-) create mode 100644 optd-cli/examples/logical_rules.opt create mode 100644 optd/src/dsl/analyzer/hir/display.rs diff --git a/Cargo.lock b/Cargo.lock index 25bb15fc..a3ae18be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2020,6 +2020,7 @@ dependencies = [ "clap", "colored", "optd", + "tokio", ] [[package]] diff --git a/optd-cli/Cargo.toml b/optd-cli/Cargo.toml index 030f0d69..a5cae97e 100644 --- a/optd-cli/Cargo.toml +++ b/optd-cli/Cargo.toml @@ -9,3 +9,4 @@ optd = { path = "../optd" } clap = { version = "4.5.37", features = ["derive"] } colored = "3.0.0" +tokio = "1.44.2" diff --git a/optd-cli/examples/logical_rules.opt b/optd-cli/examples/logical_rules.opt new file mode 100644 index 00000000..67c15b3f --- /dev/null +++ b/optd-cli/examples/logical_rules.opt @@ -0,0 +1,281 @@ +data Physical +data PhysicalProperties +data Statistics +data LogicalProperties + +data Logical = + | Add(left: Logical, right: Logical) + | Sub(left: Logical, right: Logical) + | Mult(left: Logical, right: Logical) + | Div(left: Logical, right: Logical) + | Pow(base: Logical, exponent: Logical) + | Neg(expr: Logical) + | Max(left: Logical, right: Logical) + | Min(left: Logical, right: Logical) + \ Const(val: I64) + +fn recursive_power(base: I64, exp: I64): I64 = + if exp == 0 then + 1 + else if exp == 1 then + base + else + base * recursive_power(base, exp - 1) + +fn (op: Logical) evaluate(): I64 = match op + | Add(left, right) -> left.evaluate() + right.evaluate() + | Sub(left, right) -> left.evaluate() - right.evaluate() + | Mult(left, right) -> left.evaluate() * right.evaluate() + | Div(left, right) -> + let r = right.evaluate() in + if r == 0 then + fail("cannot div by zero") + else + left.evaluate() / r + | Pow(base, exp) -> + let + b = base.evaluate(), + e = exp.evaluate() + in + if e < 0 then fail("negative exponent") else recursive_power(b, e) + | Neg(expr) -> -expr.evaluate() + | Max(left, right) -> + let + l = left.evaluate(), + r = right.evaluate() + in + if l > r then l else r + | Min(left, right) -> + let + l = left.evaluate(), + r = right.evaluate() + in + if l < r then l else r + \ Const(val) -> val + +fn (op: Logical) mult_commute(): Logical? = match op + | Mult(left, right) -> Mult(right, left) + \ _ -> none + +fn (op: Logical) add_commute(): Logical? = match op + | Add(left, right) -> Add(right, left) + \ _ -> none + +fn (op: Logical) const_fold_add(): Logical? = match op + | Add(Const(a), Const(b)) -> Const(a + b) + \ _ -> none + +fn (op: Logical) const_fold_mult(): Logical? = match op + | Mult(Const(a), Const(b)) -> Const(a * b) + \ _ -> none + +fn (op: Logical) const_fold_sub(): Logical? = match op + | Sub(Const(a), Const(b)) -> Const(a - b) + \ _ -> none + +fn (op: Logical) const_fold_div(): Logical? = match op + | Div(Const(a), Const(b)) -> + if b == 0 then none else Const(a / b) + \ _ -> none + +fn (op: Logical) mult_by_zero(): Logical? = match op + | Mult(_, Const(0)) -> Const(0) + | Mult(Const(0), _) -> Const(0) + \ _ -> none + +fn (op: Logical) mult_by_one(): Logical? = match op + | Mult(expr, Const(1)) -> expr + | Mult(Const(1), expr) -> expr + \ _ -> none + +fn (op: Logical) add_by_zero(): Logical? 
= match op + | Add(expr, Const(0)) -> expr + | Add(Const(0), expr) -> expr + \ _ -> none + +fn (op: Logical) double_neg(): Logical? = match op + | Neg(Neg(expr)) -> expr + \ _ -> none + +fn (op: Logical) pow_zero(): Logical? = match op + | Pow(_, Const(0)) -> Const(1) + \ _ -> none + +fn (op: Logical) pow_one(): Logical? = match op + | Pow(base, Const(1)) -> base + \ _ -> none + +fn (op: Logical) distributive(): Logical? = match op + | Mult(factor, Add(left, right)) -> + Add(Mult(factor, left), Mult(factor, right)) + | Mult(Add(left, right), factor) -> + Add(Mult(left, factor), Mult(right, factor)) + \ _ -> none + +fn (op: Logical) sub_to_add(): Logical? = match op + | Sub(left, right) -> Add(left, Neg(right)) + \ _ -> none + +fn (op: Logical) same_minmax(): Logical? = match op + | Min(expr1, expr2) -> if expr1 == expr2 then expr1 else none + | Max(expr1, expr2) -> if expr1 == expr2 then expr1 else none + \ _ -> none + +fn (op: Logical) const_fold_minmax(): Logical? = match op + | Min(Const(a), Const(b)) -> Const(if a < b then a else b) + | Max(Const(a), Const(b)) -> Const(if a > b then a else b) + \ _ -> none + +fn (op: Logical) nested_minmax(): Logical? = match op + | Min(Min(a, b), c) -> Min(a, Min(b, c)) + | Max(Max(a, b), c) -> Max(a, Max(b, c)) + \ _ -> none + +[run] +fn build_calculator_expr() = + let + const2 = Const(2), + const3 = Const(3), + const4 = Const(4), + addition = Add(const2, const3), + multiplication = Mult(addition, const4) + in + multiplication.evaluate() + +[run] +fn run_mult_commute() = + let + const5 = Const(5), + const10 = Const(10), + mult = Mult(const5, const10) + in + mult_commute(mult) + +[run] +fn run_add_commute() = + let + const7 = Const(7), + const12 = Const(12), + addition = Add(const7, const12) + in + add_commute(addition) + +[run] +fn run_const_fold_add() = + let + const5 = Const(5), + const8 = Const(8), + addition = Add(const5, const8) + in + const_fold_add(addition) + +[run] +fn run_const_fold_mult() = + let + const6 = Const(6), + const9 = Const(9), + multiplication = Mult(const6, const9) + in + const_fold_mult(multiplication) + +[run] +fn run_mult_by_zero() = + let + const0 = Const(0), + const42 = Const(42), + multiplication = Mult(const42, const0) + in + mult_by_zero(multiplication) + +[run] +fn run_mult_by_one() = + let + const1 = Const(1), + const25 = Const(25), + multiplication = Mult(const1, const25) + in + mult_by_one(multiplication) + +[run] +fn run_add_by_zero() = + let + const0 = Const(0), + const17 = Const(17), + addition = Add(const17, const0) + in + add_by_zero(addition) + +[run] +fn run_double_neg() = + let + const13 = Const(13), + neg1 = Neg(const13), + neg2 = Neg(neg1) + in + double_neg(neg2) + +[run] +fn run_distributive() = + let + const2 = Const(2), + const3 = Const(3), + const4 = Const(4), + addition = Add(const3, const4), + multiplication = Mult(const2, addition) + in + distributive(multiplication) + +[run] +fn run_sub_to_add() = + let + const8 = Const(8), + const3 = Const(3), + subtraction = Sub(const8, const3) + in + sub_to_add(subtraction) + +// [run] +// fn run_same_minmax() = +// let +// const5 = Const(5), +// min_op = Min(const5, const5) +// in +// same_minmax(min_op) + +[run] +fn run_const_fold_minmax() = + let + const8 = Const(8), + const15 = Const(15), + min_op = Min(const8, const15) + in + const_fold_minmax(min_op) + +[run] +fn run_nested_minmax() = + let + const3 = Const(3), + const6 = Const(6), + const9 = Const(9), + inner_min = Min(const3, const6), + outer_min = Min(inner_min, const9) + in + nested_minmax(outer_min) + 
+[run] +fn run_pow_zero() = + let + const7 = Const(7), + const0 = Const(0), + pow_zero_expr = Pow(const7, const0) + in + pow_zero(pow_zero_expr) + +[run] +fn run_division_by_zero() = + let + const5 = Const(5), + const0 = Const(0), + division = Div(const5, const0) + in + division.evaluate() \ No newline at end of file diff --git a/optd-cli/src/main.rs b/optd-cli/src/main.rs index d3e0a6cf..85622134 100644 --- a/optd-cli/src/main.rs +++ b/optd-cli/src/main.rs @@ -17,6 +17,9 @@ //! # Get help: //! optd-cli --help //! optd-cli compile --help +//! +//! # Run functions marked with [run] annotation: +//! optd-cli run-functions path/to/file.opt //! ``` //! //! When developing, you can run through cargo: @@ -26,15 +29,21 @@ //! cargo run --bin optd-cli -- compile path/to/example.opt --verbose //! cargo run --bin optd-cli -- compile path/to/example.opt --verbose --show-ast --show-hir //! cargo run --bin optd-cli -- compile path/to/example.opt --mock-udfs hello get_schema world +//! cargo run --bin optd-cli -- run-functions path/to/example.opt //! ``` use clap::{Parser, Subcommand}; use colored::Colorize; use optd::catalog::Catalog; -use optd::dsl::analyzer::hir::{CoreData, Udf, Value}; +use optd::catalog::iceberg::memory_catalog; +use optd::dsl::analyzer::hir::{CoreData, HIR, Udf, Value}; use optd::dsl::compile::{Config, compile_hir}; +use optd::dsl::engine::{Continuation, Engine, EngineResponse}; use optd::dsl::utils::errors::{CompileError, Diagnose}; use std::collections::HashMap; +use std::sync::Arc; +use tokio::runtime::Runtime; +use tokio::task::JoinSet; #[derive(Parser)] #[command( @@ -52,6 +61,8 @@ struct Cli { enum Commands { /// Compile a DSL file (parse and analyze). Compile(Config), + /// Run functions annotated with [run]. + RunFunctions(Config), } /// A unimplemented user-defined function. @@ -69,17 +80,133 @@ fn main() -> Result<(), Vec> { }; udfs.insert("unimplemented_udf".to_string(), udf.clone()); - let Commands::Compile(config) = cli.command; + match cli.command { + Commands::Compile(config) => { + for mock_udf in config.mock_udfs() { + udfs.insert(mock_udf.to_string(), udf.clone()); + } + + let _ = compile_hir(config, udfs).unwrap_or_else(|errors| handle_errors(&errors)); + Ok(()) + } + Commands::RunFunctions(config) => { + // TODO(Connor): Add support for running functions with real UDFs. + for mock_udf in config.mock_udfs() { + udfs.insert(mock_udf.to_string(), udf.clone()); + } + + let hir = compile_hir(config, udfs).unwrap_or_else(|errors| handle_errors(&errors)); - for mock_udf in config.mock_udfs() { - udfs.insert(mock_udf.to_string(), udf.clone()); + run_all_functions(&hir) + } } +} - let _hir = compile_hir(config, udfs).unwrap_or_else(|errors| handle_errors(&errors)); +/// Result of running a function. +struct FunctionResult { + name: String, + result: EngineResponse, +} + +/// Run all functions found in the HIR, marked with [run]. +fn run_all_functions(hir: &HIR) -> Result<(), Vec> { + println!("\n{} {}\n", "•".green(), "Running functions...".green()); + + let functions = find_functions(hir); + + if functions.is_empty() { + println!("No functions found annotated with [run]"); + return Ok(()); + } + + println!("Found {} functions to run", functions.len()); + + // Create a multi-threaded runtime for parallel execution. + let runtime = Runtime::new().unwrap(); + let function_results = runtime.block_on(run_functions_in_parallel(hir, functions)); + + // Process and display function results. 
+ let success_count = process_function_results(function_results); + + println!( + "\n{} {}", + "Execution Results:".yellow(), + format!("{} functions executed", success_count).yellow() + ); Ok(()) } +async fn run_functions_in_parallel(hir: &HIR, functions: Vec) -> Vec { + let catalog = Arc::new(memory_catalog()); + let mut set = JoinSet::new(); + + for function_name in functions { + let engine = Engine::new(hir.context.clone(), catalog.clone()); + let name = function_name.clone(); + + set.spawn(async move { + // Create a continuation that returns itself. + let result_handler: Continuation = + Arc::new(|value| Box::pin(async move { value })); + + // Launch the function with an empty vector of arguments. + let result = engine.launch_rule(&name, vec![], result_handler).await; + FunctionResult { name, result } + }); + } + + // Collect all function results. + let mut results = Vec::new(); + while let Some(result) = set.join_next().await { + if let Ok(function_result) = result { + results.push(function_result); + } + } + + results +} + +/// Process function results and display them. +fn process_function_results(function_results: Vec) -> usize { + let mut success_count = 0; + + for function_result in function_results { + println!("\n{} {}", "Function:".blue(), function_result.name); + + match function_result.result { + EngineResponse::Return(value, _) => { + // Check if the result is a failure. + if matches!(value.data, CoreData::Fail(_)) { + println!(" {}: Function failed: {}", "Error".red(), value); + } else { + println!(" {}: {}", "Result".green(), value); + success_count += 1; + } + } + _ => unreachable!(), // For now, unless we add a special UDF that builds a group / goal. + } + } + + success_count +} + +/// Find functions with the [run] annotation. +fn find_functions(hir: &HIR) -> Vec { + let mut functions = Vec::new(); + + for (name, _) in hir.context.get_all_bindings() { + if let Some(annotations) = hir.annotations.get(name) { + if annotations.iter().any(|a| a == "run") { + functions.push(name.clone()); + } + } + } + + functions +} + +/// Display error details and exit the program. fn handle_errors(errors: &[CompileError]) -> ! { eprintln!( "\n{} {}\n", diff --git a/optd/src/dsl/analyzer/hir/display.rs b/optd/src/dsl/analyzer/hir/display.rs new file mode 100644 index 00000000..7a79d90a --- /dev/null +++ b/optd/src/dsl/analyzer/hir/display.rs @@ -0,0 +1,221 @@ +//! Display implementations for HIR types. +//! +//! This module provides Display trait implementations for various HIR types, +//! enabling human-readable formatting of expressions, values, and patterns. 
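+//!
+//! As a rough illustration of the output format (see the implementations below):
+//! a literal array value renders as `[1, 2, 3]`, an unmaterialized logical
+//! operator as `group(1)`, and a materialized operator as `Tag(data; children)`,
+//! optionally suffixed with its group id.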
+ +use super::map::Map; +use super::{ + BinOp, CoreData, FunKind, Goal, GroupId, Literal, LogicalOp, Materializable, NoMetadata, + PhysicalOp, UnaryOp, Value, +}; +use std::fmt; + +impl fmt::Display for Value { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.data { + CoreData::Literal(lit) => match lit { + Literal::Int64(i) => write!(f, "{}", i), + Literal::Float64(fl) => write!(f, "{}", fl), + Literal::String(s) => write!(f, "\"{}\"", s), + Literal::Bool(b) => write!(f, "{}", b), + Literal::Unit => write!(f, "()"), + }, + CoreData::Array(items) => { + write!(f, "[")?; + for (i, item) in items.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", item)?; + } + write!(f, "]") + } + CoreData::Tuple(items) => { + write!(f, "(")?; + for (i, item) in items.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", item)?; + } + if items.len() == 1 { + write!(f, ",")?; // Add trailing comma for 1-tuples + } + write!(f, ")") + } + CoreData::Map(map) => format_map(f, map), + CoreData::Struct(name, fields) => { + write!(f, "{}(", name)?; + for (i, field) in fields.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", field)?; + } + write!(f, ")") + } + CoreData::Function(fun_kind) => match fun_kind { + FunKind::Closure(params, _) => { + write!(f, "λ(")?; + for (i, param) in params.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", param)?; + } + write!(f, ") → ‹code›") + } + FunKind::Udf(_) => write!(f, "udf@native"), + }, + CoreData::Fail(value) => write!(f, "fail({})", value), + CoreData::Logical(materializable) => match materializable { + Materializable::Materialized(logical_op) => format_logical_op(f, logical_op), + Materializable::UnMaterialized(group_id) => { + write!(f, "group({})", group_id.0) + } + }, + CoreData::Physical(materializable) => match materializable { + Materializable::Materialized(physical_op) => format_physical_op(f, physical_op), + Materializable::UnMaterialized(goal) => format_goal(f, goal), + }, + CoreData::None => write!(f, "none"), + } + } +} + +// Helper function to format Map contents +fn format_map(f: &mut fmt::Formatter<'_>, map: &Map) -> fmt::Result { + write!(f, "{{")?; + let mut first = true; + + // Sort keys for deterministic output + let mut entries: Vec<_> = map.inner.iter().collect(); + entries.sort_by(|a, b| format!("{:?}", a.0).cmp(&format!("{:?}", b.0))); + + for (key, value) in entries { + if !first { + write!(f, ", ")?; + } + first = false; + + // Format the key and value + write!(f, "{:?} → {}", key, value)?; + } + write!(f, "}}") +} + +// Helper function to format logical operators +fn format_logical_op(f: &mut fmt::Formatter<'_>, op: &LogicalOp>) -> fmt::Result { + write!(f, "{}(", op.operator.tag)?; + + // Format data + for (i, data) in op.operator.data.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", data)?; + } + + // Format children if any + if !op.operator.children.is_empty() { + if !op.operator.data.is_empty() { + write!(f, "; ")?; + } + for (i, child) in op.operator.children.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", child)?; + } + } + + // Add group_id if present + if let Some(group_id) = op.group_id { + write!(f, ")[{}]", group_id.0) + } else { + write!(f, ")") + } +} + +// Helper function to format physical operators +fn format_physical_op( + f: &mut fmt::Formatter<'_>, + op: &PhysicalOp, NoMetadata>, +) -> fmt::Result { + write!(f, "PhysOp:{}(", op.operator.tag)?; + + // 
Format data + for (i, data) in op.operator.data.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", data)?; + } + + // Format children if any + if !op.operator.children.is_empty() { + if !op.operator.data.is_empty() { + write!(f, "; ")?; + } + for (i, child) in op.operator.children.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", child)?; + } + } + + write!(f, ")")?; + + // Add goal if present + if let Some(goal) = &op.goal { + write!(f, "[goal: {}]", goal.group_id.0)?; + } + + // Add cost if present + if let Some(cost) = &op.cost { + write!(f, "[cost: {}]", cost)?; + } + + Ok(()) +} + +// Helper function to format goals +fn format_goal(f: &mut fmt::Formatter<'_>, goal: &Goal) -> fmt::Result { + write!(f, "Goal[{}]{{{}}}", goal.group_id.0, goal.properties) +} + +// For completeness, let's add Display implementations for binary and unary operators + +impl fmt::Display for BinOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BinOp::Add => write!(f, "+"), + BinOp::Sub => write!(f, "-"), + BinOp::Mul => write!(f, "*"), + BinOp::Div => write!(f, "/"), + BinOp::Lt => write!(f, "<"), + BinOp::Eq => write!(f, "=="), + BinOp::And => write!(f, "&&"), + BinOp::Or => write!(f, "||"), + BinOp::Range => write!(f, ".."), + BinOp::Concat => write!(f, "++"), + } + } +} + +impl fmt::Display for UnaryOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + UnaryOp::Neg => write!(f, "-"), + UnaryOp::Not => write!(f, "!"), + } + } +} + +// Add Display implementation for GroupId for convenience +impl fmt::Display for GroupId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "G{}", self.0) + } +} diff --git a/optd/src/dsl/analyzer/hir/mod.rs b/optd/src/dsl/analyzer/hir/mod.rs index f61949d1..29b3ec08 100644 --- a/optd/src/dsl/analyzer/hir/mod.rs +++ b/optd/src/dsl/analyzer/hir/mod.rs @@ -24,6 +24,7 @@ use std::fmt::Debug; use std::{collections::HashMap, sync::Arc}; pub(crate) mod context; +mod display; pub(crate) mod map; /// Unique identifier for variables, functions, types, etc. From d0ec48c89c9b6cb79b45425332c24b897ae013d1 Mon Sep 17 00:00:00 2001 From: AlSchlo <79570602+AlSchlo@users.noreply.github.com> Date: Wed, 7 May 2025 00:53:57 +0200 Subject: [PATCH 8/8] (feat) Add Run Annotation Check (#100) As title, also refactored the annotation checking process. --- optd-cli/examples/tutorial.opt | 5 +- optd/src/dsl/analyzer/errors.rs | 80 ++++++------------- optd/src/dsl/analyzer/into_hir/annotations.rs | 64 ++++++++++++++- 3 files changed, 87 insertions(+), 62 deletions(-) diff --git a/optd-cli/examples/tutorial.opt b/optd-cli/examples/tutorial.opt index 14bca04b..38474976 100644 --- a/optd-cli/examples/tutorial.opt +++ b/optd-cli/examples/tutorial.opt @@ -184,7 +184,6 @@ data Physical = data Catalog data LogicalProperties(schema: Schema) data PhysicalProperties(order_by: [Bool]) // e.g., define sorting, partitioning, etc. -data CostedProperties // e.g., define cost model, statistics, etc. data Statistics // e.g., define histograms, MCVs, etc. data Schema(columns: [Column]) @@ -196,7 +195,7 @@ data Column(name: String, data_type: String, is_nullable: Bool) // Functions can be defined in two equivalent ways: -fn example1(arg1: Logical, arg2: I64): I64 = 5 // Call with: example1(arg1, arg2) +fn example1(arg1: Logical, arg2: I64) = 5 // Call with: example1(arg1, arg2) // Or as a member function: @@ -389,7 +388,7 @@ fn (expr: Physical*) cost(): F64 = 0 // through the logical plan. 
fn (log: Logical*) derive(): LogicalProperties = match log - | Get(table_name) -> Catalog.get_table_schema(table_name) + | Get(table_name) -> catalog.get_table_schema(table_name) | Filter(child, _) -> child.properties() | Join(left, right, join_type, _) -> let diff --git a/optd/src/dsl/analyzer/errors.rs b/optd/src/dsl/analyzer/errors.rs index 5554c2a1..37c887a9 100644 --- a/optd/src/dsl/analyzer/errors.rs +++ b/optd/src/dsl/analyzer/errors.rs @@ -1,9 +1,6 @@ use super::{hir::Identifier, type_checks::registry::Type}; use crate::dsl::{ - analyzer::{ - into_hir::annotations::{IMPLEMENTATION_SIGNATURE_TYPE, TRANSFORMATION_SIGNATURE_TYPE}, - type_checks::converter::type_display, - }, + analyzer::type_checks::converter::type_display, utils::{ errors::Diagnose, span::{Span, Spanned}, @@ -126,16 +123,11 @@ pub enum AnalyzerErrorKind { unknowns: HashMap, }, - InvalidTransformation { + InvalidAnnotation { span: Span, - actual_signature: Type, - // To be able to call display function of Type. - unknowns: HashMap, - }, - - InvalidImplementation { - span: Span, - actual_signature: Type, + annotation_name: String, + actual_type: Type, + expected_type: Type, // To be able to call display function of Type. unknowns: HashMap, }, @@ -318,27 +310,18 @@ impl AnalyzerErrorKind { .into() } - pub fn new_invalid_transformation( + pub fn new_invalid_annotation( span: &Span, - actual_signature: &Type, + annotation_name: &str, + actual_type: &Type, + expected_type: &Type, unknowns: HashMap, ) -> Box { - Self::InvalidTransformation { + Self::InvalidAnnotation { span: span.clone(), - actual_signature: actual_signature.clone(), - unknowns, - } - .into() - } - - pub fn new_invalid_implementation( - span: &Span, - actual_signature: &Type, - unknowns: HashMap, - ) -> Box { - Self::InvalidImplementation { - span: span.clone(), - actual_signature: actual_signature.clone(), + annotation_name: annotation_name.to_string(), + actual_type: actual_type.clone(), + expected_type: expected_type.clone(), unknowns, } .into() @@ -480,39 +463,23 @@ impl Diagnose for Box { "Only functions, maps, and arrays can be called", ), InvalidFieldAccess { object, span, field, unknowns } => self.build_invalid_field_access_report(object, span, field, unknowns), - InvalidTransformation { + InvalidAnnotation { span, - actual_signature, + annotation_name, + actual_type, + expected_type, unknowns, } => { self.build_single_span_report( span, - "Invalid transformation function signature", + &format!("Invalid '{}' function signature", annotation_name), &format!( "Found: '{}'", - type_display(actual_signature, unknowns) + type_display(actual_type, unknowns) ), &format!( "Expected a subtype of: '{}'", - type_display(&TRANSFORMATION_SIGNATURE_TYPE, unknowns) - ), - ) - }, - InvalidImplementation { - span, - actual_signature, - unknowns, - } => { - self.build_single_span_report( - span, - "Invalid implementation function signature", - &format!( - "Found: '{}'", - type_display(actual_signature, unknowns) - ), - &format!( - "Expected a subtype of: '{}'", - type_display(&IMPLEMENTATION_SIGNATURE_TYPE, unknowns) + type_display(expected_type, unknowns) ), ) }, @@ -551,10 +518,9 @@ impl Diagnose for Box { ArgumentNumberMismatch { span, .. } => span, InvalidCallReceiver { span, .. } => span, InvalidFieldAccess { span, .. } => span, - InvalidTransformation { span, .. } => span, - InvalidImplementation { span, .. } => span, - InvalidArrayDecomposition { pattern_span, .. } => pattern_span, // Use pattern span as primary - ReservedType { span, .. 
} => span, // New case for ReservedType + InvalidAnnotation { span, .. } => span, + InvalidArrayDecomposition { pattern_span, .. } => pattern_span, + ReservedType { span, .. } => span, }; (span.src_file.clone(), Source::from(self.src_code.clone())) diff --git a/optd/src/dsl/analyzer/into_hir/annotations.rs b/optd/src/dsl/analyzer/into_hir/annotations.rs index 24a06ca5..e1516a52 100644 --- a/optd/src/dsl/analyzer/into_hir/annotations.rs +++ b/optd/src/dsl/analyzer/into_hir/annotations.rs @@ -27,6 +27,14 @@ pub static IMPLEMENTATION_SIGNATURE_TYPE: Lazy = Lazy::new(|| { pub const IMPLEMENTATION_ANNOTATION: &str = "implementation"; +pub static RUN_SIGNATURE_TYPE: Lazy = Lazy::new(|| { + use TypeKind::*; + let param_type = Unit.into(); + Closure(param_type, Universe.into()).into() +}); + +pub const RUN_ANNOTATION: &str = "run"; + /// Validates that a function's type matches the expected signature for a given annotation /// /// # Arguments @@ -48,18 +56,33 @@ pub(super) fn validate_annotation( match annotation { TRANSFORMATION_ANNOTATION => { if !registry.is_subtype(function_type, &TRANSFORMATION_SIGNATURE_TYPE) { - return Err(AnalyzerErrorKind::new_invalid_transformation( + return Err(AnalyzerErrorKind::new_invalid_annotation( function_span, + TRANSFORMATION_ANNOTATION, function_type, + &TRANSFORMATION_SIGNATURE_TYPE, registry.resolved_unknown.clone(), )); } } IMPLEMENTATION_ANNOTATION => { if !registry.is_subtype(function_type, &IMPLEMENTATION_SIGNATURE_TYPE) { - return Err(AnalyzerErrorKind::new_invalid_implementation( + return Err(AnalyzerErrorKind::new_invalid_annotation( + function_span, + IMPLEMENTATION_ANNOTATION, + function_type, + &IMPLEMENTATION_SIGNATURE_TYPE, + registry.resolved_unknown.clone(), + )); + } + } + RUN_ANNOTATION => { + if !registry.is_subtype(function_type, &RUN_SIGNATURE_TYPE) { + return Err(AnalyzerErrorKind::new_invalid_annotation( function_span, + RUN_ANNOTATION, function_type, + &RUN_SIGNATURE_TYPE, registry.resolved_unknown.clone(), )); } @@ -255,6 +278,43 @@ mod tests { assert!(result.is_err()); } + #[test] + fn test_validate_run_annotation_success() { + let mut registry = TypeRegistry::new(); + + // Create a valid run function type + let function_type = + TypeKind::Closure(TypeKind::Unit.into(), TypeKind::Universe.into()).into(); + + let result = validate_annotation( + RUN_ANNOTATION, + &function_type, + &create_test_span(), + &mut registry, + ); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_run_annotation_failure_wrong_param() { + let mut registry = TypeRegistry::new(); + + // Create an invalid function type (wrong parameter type) + let function_type = TypeKind::Closure( + TypeKind::I64.into(), // Wrong parameter type, should be Unit + TypeKind::Universe.into(), + ) + .into(); + + let result = validate_annotation( + RUN_ANNOTATION, + &function_type, + &create_test_span(), + &mut registry, + ); + assert!(result.is_err()); + } + #[test] fn test_validate_unknown_annotation() { let mut registry = TypeRegistry::new();