Skip to content

Fix: handle column name collisions when combining UNION logical inputs & nested Column expressions in maybe_fix_physical_column_name #16064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
77 changes: 65 additions & 12 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ use arrow::array::{builder::StringBuilder, RecordBatch};
use arrow::compute::SortOptions;
use arrow::datatypes::{Schema, SchemaRef};
use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema,
ScalarValue,
Expand Down Expand Up @@ -2075,29 +2077,36 @@ fn maybe_fix_physical_column_name(
expr: Result<Arc<dyn PhysicalExpr>>,
input_physical_schema: &SchemaRef,
) -> Result<Arc<dyn PhysicalExpr>> {
if let Ok(e) = &expr {
if let Some(column) = e.as_any().downcast_ref::<Column>() {
let physical_field = input_physical_schema.field(column.index());
let Ok(expr) = expr else { return expr };
expr.transform_down(|node| {
if let Some(column) = node.as_any().downcast_ref::<Column>() {
let idx = column.index();
let physical_field = input_physical_schema.field(idx);
let expr_col_name = column.name();
let physical_name = physical_field.name();

if physical_name != expr_col_name {
if expr_col_name != physical_name {
// handle edge cases where the physical_name contains ':'.
let colon_count = physical_name.matches(':').count();
let mut splits = expr_col_name.match_indices(':');
let split_pos = splits.nth(colon_count);

if let Some((idx, _)) = split_pos {
let base_name = &expr_col_name[..idx];
if let Some((i, _)) = split_pos {
let base_name = &expr_col_name[..i];
if base_name == physical_name {
let updated_column = Column::new(physical_name, column.index());
return Ok(Arc::new(updated_column));
let updated_column = Column::new(physical_name, idx);
return Ok(Transformed::yes(Arc::new(updated_column)));
}
}
}

// If names already match or fix is not possible, just leave it as it is
Ok(Transformed::no(node))
} else {
Ok(Transformed::no(node))
}
}
expr
})
.data()
}

struct OptimizationInvariantChecker<'a> {
Expand Down Expand Up @@ -2203,8 +2212,11 @@ mod tests {
};
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::TaskContext;
use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore};
use datafusion_expr::{
col, lit, LogicalPlanBuilder, Operator, UserDefinedLogicalNodeCore,
};
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_physical_expr::expressions::{BinaryExpr, IsNotNullExpr};
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};

Expand Down Expand Up @@ -2769,6 +2781,47 @@ mod tests {

assert_eq!(col.name(), "metric:avg");
}

#[tokio::test]
async fn test_maybe_fix_nested_column_name_with_colon() {
let schema = Schema::new(vec![Field::new("column", DataType::Int32, false)]);
let schema_ref: SchemaRef = Arc::new(schema);

// Construct the nested expr
let col_expr = Arc::new(Column::new("column:1", 0)) as Arc<dyn PhysicalExpr>;
let is_not_null_expr = Arc::new(IsNotNullExpr::new(col_expr.clone()));

// Create a binary expression and put the column inside
let binary_expr = Arc::new(BinaryExpr::new(
is_not_null_expr.clone(),
Operator::Or,
is_not_null_expr.clone(),
)) as Arc<dyn PhysicalExpr>;

let fixed_expr =
maybe_fix_physical_column_name(Ok(binary_expr), &schema_ref).unwrap();

let bin = fixed_expr
.as_any()
.downcast_ref::<BinaryExpr>()
.expect("Expected BinaryExpr");

// Check that both sides where renamed
for expr in &[bin.left(), bin.right()] {
let is_not_null = expr
.as_any()
.downcast_ref::<IsNotNullExpr>()
.expect("Expected IsNotNull");

let col = is_not_null
.arg()
.as_any()
.downcast_ref::<Column>()
.expect("Expected Column");

assert_eq!(col.name(), "column");
}
}
struct ErrorExtensionPlanner {}

#[async_trait]
Expand Down
10 changes: 9 additions & 1 deletion datafusion/physical-plan/src/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,12 @@ fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {

let fields = (0..first_schema.fields().len())
.map(|i| {
inputs
// We take the name from the left side of the union to match how names are coerced during logical planning,
// which also uses the left side names.
let base_field = first_schema.field(i).clone();

// Coerce metadata and nullability across all inputs
let merged_field = inputs
.iter()
.enumerate()
.map(|(input_idx, input)| {
Expand All @@ -562,6 +567,9 @@ fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {
// We can unwrap this because if inputs was empty, this would've already panic'ed when we
// indexed into inputs[0].
.unwrap()
.with_name(base_field.name());

merged_field
})
.collect::<Vec<_>>();

Expand Down
24 changes: 24 additions & 0 deletions datafusion/substrait/tests/cases/consumer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,4 +560,28 @@ mod tests {
);
Ok(())
}

#[tokio::test]
async fn test_multiple_unions() -> Result<()> {
let plan_str = test_plan_to_string("multiple_unions.json").await?;
assert_snapshot!(
plan_str,
@r#"
Projection: Utf8("people") AS product_category, Utf8("people")__temp__0 AS product_type, product_key
Union
Projection: Utf8("people"), Utf8("people") AS Utf8("people")__temp__0, sales.product_key
Left Join: sales.product_key = food.@food_id
TableScan: sales
TableScan: food
Union
Projection: people.$f3, people.$f5, people.product_key0
Left Join: people.product_key0 = food.@food_id
TableScan: people
TableScan: food
TableScan: more_products
"#
);

Ok(())
}
}
Loading