Skip to content

Commit e5f596b

Browse files
Fix: handle column name collisions when combining UNION logical inputs & nested Column expressions in maybe_fix_physical_column_name (#16064)
* Fix union schema name coercion * Address renaming for columns that are not in the top level as well * Add unit test * Format * Use insta tests properly * Address review - comment + minor simplification change --------- Co-authored-by: Berkay Şahin <[email protected]>
1 parent 67a2173 commit e5f596b

File tree

4 files changed

+426
-13
lines changed

4 files changed

+426
-13
lines changed

datafusion/core/src/physical_planner.rs

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ use arrow::array::{builder::StringBuilder, RecordBatch};
6262
use arrow::compute::SortOptions;
6363
use arrow::datatypes::{Schema, SchemaRef};
6464
use datafusion_common::display::ToStringifiedPlan;
65-
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
65+
use datafusion_common::tree_node::{
66+
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
67+
};
6668
use datafusion_common::{
6769
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema,
6870
ScalarValue,
@@ -2075,29 +2077,36 @@ fn maybe_fix_physical_column_name(
20752077
expr: Result<Arc<dyn PhysicalExpr>>,
20762078
input_physical_schema: &SchemaRef,
20772079
) -> Result<Arc<dyn PhysicalExpr>> {
2078-
if let Ok(e) = &expr {
2079-
if let Some(column) = e.as_any().downcast_ref::<Column>() {
2080-
let physical_field = input_physical_schema.field(column.index());
2080+
let Ok(expr) = expr else { return expr };
2081+
expr.transform_down(|node| {
2082+
if let Some(column) = node.as_any().downcast_ref::<Column>() {
2083+
let idx = column.index();
2084+
let physical_field = input_physical_schema.field(idx);
20812085
let expr_col_name = column.name();
20822086
let physical_name = physical_field.name();
20832087

2084-
if physical_name != expr_col_name {
2088+
if expr_col_name != physical_name {
20852089
// handle edge cases where the physical_name contains ':'.
20862090
let colon_count = physical_name.matches(':').count();
20872091
let mut splits = expr_col_name.match_indices(':');
20882092
let split_pos = splits.nth(colon_count);
20892093

2090-
if let Some((idx, _)) = split_pos {
2091-
let base_name = &expr_col_name[..idx];
2094+
if let Some((i, _)) = split_pos {
2095+
let base_name = &expr_col_name[..i];
20922096
if base_name == physical_name {
2093-
let updated_column = Column::new(physical_name, column.index());
2094-
return Ok(Arc::new(updated_column));
2097+
let updated_column = Column::new(physical_name, idx);
2098+
return Ok(Transformed::yes(Arc::new(updated_column)));
20952099
}
20962100
}
20972101
}
2102+
2103+
// If names already match or fix is not possible, just leave it as it is
2104+
Ok(Transformed::no(node))
2105+
} else {
2106+
Ok(Transformed::no(node))
20982107
}
2099-
}
2100-
expr
2108+
})
2109+
.data()
21012110
}
21022111

21032112
struct OptimizationInvariantChecker<'a> {
@@ -2203,8 +2212,11 @@ mod tests {
22032212
};
22042213
use datafusion_execution::runtime_env::RuntimeEnv;
22052214
use datafusion_execution::TaskContext;
2206-
use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore};
2215+
use datafusion_expr::{
2216+
col, lit, LogicalPlanBuilder, Operator, UserDefinedLogicalNodeCore,
2217+
};
22072218
use datafusion_functions_aggregate::expr_fn::sum;
2219+
use datafusion_physical_expr::expressions::{BinaryExpr, IsNotNullExpr};
22082220
use datafusion_physical_expr::EquivalenceProperties;
22092221
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
22102222

@@ -2769,6 +2781,47 @@ mod tests {
27692781

27702782
assert_eq!(col.name(), "metric:avg");
27712783
}
2784+
2785+
#[tokio::test]
2786+
async fn test_maybe_fix_nested_column_name_with_colon() {
2787+
let schema = Schema::new(vec![Field::new("column", DataType::Int32, false)]);
2788+
let schema_ref: SchemaRef = Arc::new(schema);
2789+
2790+
// Construct the nested expr
2791+
let col_expr = Arc::new(Column::new("column:1", 0)) as Arc<dyn PhysicalExpr>;
2792+
let is_not_null_expr = Arc::new(IsNotNullExpr::new(col_expr.clone()));
2793+
2794+
// Create a binary expression and put the column inside
2795+
let binary_expr = Arc::new(BinaryExpr::new(
2796+
is_not_null_expr.clone(),
2797+
Operator::Or,
2798+
is_not_null_expr.clone(),
2799+
)) as Arc<dyn PhysicalExpr>;
2800+
2801+
let fixed_expr =
2802+
maybe_fix_physical_column_name(Ok(binary_expr), &schema_ref).unwrap();
2803+
2804+
let bin = fixed_expr
2805+
.as_any()
2806+
.downcast_ref::<BinaryExpr>()
2807+
.expect("Expected BinaryExpr");
2808+
2809+
// Check that both sides where renamed
2810+
for expr in &[bin.left(), bin.right()] {
2811+
let is_not_null = expr
2812+
.as_any()
2813+
.downcast_ref::<IsNotNullExpr>()
2814+
.expect("Expected IsNotNull");
2815+
2816+
let col = is_not_null
2817+
.arg()
2818+
.as_any()
2819+
.downcast_ref::<Column>()
2820+
.expect("Expected Column");
2821+
2822+
assert_eq!(col.name(), "column");
2823+
}
2824+
}
27722825
struct ErrorExtensionPlanner {}
27732826

27742827
#[async_trait]

datafusion/physical-plan/src/union.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,12 @@ fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {
540540

541541
let fields = (0..first_schema.fields().len())
542542
.map(|i| {
543-
inputs
543+
// We take the name from the left side of the union to match how names are coerced during logical planning,
544+
// which also uses the left side names.
545+
let base_field = first_schema.field(i).clone();
546+
547+
// Coerce metadata and nullability across all inputs
548+
let merged_field = inputs
544549
.iter()
545550
.enumerate()
546551
.map(|(input_idx, input)| {
@@ -562,6 +567,9 @@ fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {
562567
// We can unwrap this because if inputs was empty, this would've already panic'ed when we
563568
// indexed into inputs[0].
564569
.unwrap()
570+
.with_name(base_field.name());
571+
572+
merged_field
565573
})
566574
.collect::<Vec<_>>();
567575

datafusion/substrait/tests/cases/consumer_integration.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,4 +560,28 @@ mod tests {
560560
);
561561
Ok(())
562562
}
563+
564+
#[tokio::test]
565+
async fn test_multiple_unions() -> Result<()> {
566+
let plan_str = test_plan_to_string("multiple_unions.json").await?;
567+
assert_snapshot!(
568+
plan_str,
569+
@r#"
570+
Projection: Utf8("people") AS product_category, Utf8("people")__temp__0 AS product_type, product_key
571+
Union
572+
Projection: Utf8("people"), Utf8("people") AS Utf8("people")__temp__0, sales.product_key
573+
Left Join: sales.product_key = food.@food_id
574+
TableScan: sales
575+
TableScan: food
576+
Union
577+
Projection: people.$f3, people.$f5, people.product_key0
578+
Left Join: people.product_key0 = food.@food_id
579+
TableScan: people
580+
TableScan: food
581+
TableScan: more_products
582+
"#
583+
);
584+
585+
Ok(())
586+
}
563587
}

0 commit comments

Comments
 (0)