@@ -62,7 +62,9 @@ use arrow::array::{builder::StringBuilder, RecordBatch};
62
62
use arrow:: compute:: SortOptions ;
63
63
use arrow:: datatypes:: { Schema , SchemaRef } ;
64
64
use datafusion_common:: display:: ToStringifiedPlan ;
65
- use datafusion_common:: tree_node:: { TreeNode , TreeNodeRecursion , TreeNodeVisitor } ;
65
+ use datafusion_common:: tree_node:: {
66
+ Transformed , TransformedResult , TreeNode , TreeNodeRecursion , TreeNodeVisitor ,
67
+ } ;
66
68
use datafusion_common:: {
67
69
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema ,
68
70
ScalarValue ,
@@ -2075,29 +2077,36 @@ fn maybe_fix_physical_column_name(
2075
2077
expr : Result < Arc < dyn PhysicalExpr > > ,
2076
2078
input_physical_schema : & SchemaRef ,
2077
2079
) -> Result < Arc < dyn PhysicalExpr > > {
2078
- if let Ok ( e) = & expr {
2079
- if let Some ( column) = e. as_any ( ) . downcast_ref :: < Column > ( ) {
2080
- let physical_field = input_physical_schema. field ( column. index ( ) ) ;
2080
+ let Ok ( expr) = expr else { return expr } ;
2081
+ expr. transform_down ( |node| {
2082
+ if let Some ( column) = node. as_any ( ) . downcast_ref :: < Column > ( ) {
2083
+ let idx = column. index ( ) ;
2084
+ let physical_field = input_physical_schema. field ( idx) ;
2081
2085
let expr_col_name = column. name ( ) ;
2082
2086
let physical_name = physical_field. name ( ) ;
2083
2087
2084
- if physical_name != expr_col_name {
2088
+ if expr_col_name != physical_name {
2085
2089
// handle edge cases where the physical_name contains ':'.
2086
2090
let colon_count = physical_name. matches ( ':' ) . count ( ) ;
2087
2091
let mut splits = expr_col_name. match_indices ( ':' ) ;
2088
2092
let split_pos = splits. nth ( colon_count) ;
2089
2093
2090
- if let Some ( ( idx , _) ) = split_pos {
2091
- let base_name = & expr_col_name[ ..idx ] ;
2094
+ if let Some ( ( i , _) ) = split_pos {
2095
+ let base_name = & expr_col_name[ ..i ] ;
2092
2096
if base_name == physical_name {
2093
- let updated_column = Column :: new ( physical_name, column . index ( ) ) ;
2094
- return Ok ( Arc :: new ( updated_column) ) ;
2097
+ let updated_column = Column :: new ( physical_name, idx ) ;
2098
+ return Ok ( Transformed :: yes ( Arc :: new ( updated_column) ) ) ;
2095
2099
}
2096
2100
}
2097
2101
}
2102
+
2103
+ // If names already match or fix is not possible, just leave it as it is
2104
+ Ok ( Transformed :: no ( node) )
2105
+ } else {
2106
+ Ok ( Transformed :: no ( node) )
2098
2107
}
2099
- }
2100
- expr
2108
+ } )
2109
+ . data ( )
2101
2110
}
2102
2111
2103
2112
struct OptimizationInvariantChecker < ' a > {
@@ -2203,8 +2212,11 @@ mod tests {
2203
2212
} ;
2204
2213
use datafusion_execution:: runtime_env:: RuntimeEnv ;
2205
2214
use datafusion_execution:: TaskContext ;
2206
- use datafusion_expr:: { col, lit, LogicalPlanBuilder , UserDefinedLogicalNodeCore } ;
2215
+ use datafusion_expr:: {
2216
+ col, lit, LogicalPlanBuilder , Operator , UserDefinedLogicalNodeCore ,
2217
+ } ;
2207
2218
use datafusion_functions_aggregate:: expr_fn:: sum;
2219
+ use datafusion_physical_expr:: expressions:: { BinaryExpr , IsNotNullExpr } ;
2208
2220
use datafusion_physical_expr:: EquivalenceProperties ;
2209
2221
use datafusion_physical_plan:: execution_plan:: { Boundedness , EmissionType } ;
2210
2222
@@ -2769,6 +2781,47 @@ mod tests {
2769
2781
2770
2782
assert_eq ! ( col. name( ) , "metric:avg" ) ;
2771
2783
}
2784
+
2785
+ #[ tokio:: test]
2786
+ async fn test_maybe_fix_nested_column_name_with_colon ( ) {
2787
+ let schema = Schema :: new ( vec ! [ Field :: new( "column" , DataType :: Int32 , false ) ] ) ;
2788
+ let schema_ref: SchemaRef = Arc :: new ( schema) ;
2789
+
2790
+ // Construct the nested expr
2791
+ let col_expr = Arc :: new ( Column :: new ( "column:1" , 0 ) ) as Arc < dyn PhysicalExpr > ;
2792
+ let is_not_null_expr = Arc :: new ( IsNotNullExpr :: new ( col_expr. clone ( ) ) ) ;
2793
+
2794
+ // Create a binary expression and put the column inside
2795
+ let binary_expr = Arc :: new ( BinaryExpr :: new (
2796
+ is_not_null_expr. clone ( ) ,
2797
+ Operator :: Or ,
2798
+ is_not_null_expr. clone ( ) ,
2799
+ ) ) as Arc < dyn PhysicalExpr > ;
2800
+
2801
+ let fixed_expr =
2802
+ maybe_fix_physical_column_name ( Ok ( binary_expr) , & schema_ref) . unwrap ( ) ;
2803
+
2804
+ let bin = fixed_expr
2805
+ . as_any ( )
2806
+ . downcast_ref :: < BinaryExpr > ( )
2807
+ . expect ( "Expected BinaryExpr" ) ;
2808
+
2809
+ // Check that both sides where renamed
2810
+ for expr in & [ bin. left ( ) , bin. right ( ) ] {
2811
+ let is_not_null = expr
2812
+ . as_any ( )
2813
+ . downcast_ref :: < IsNotNullExpr > ( )
2814
+ . expect ( "Expected IsNotNull" ) ;
2815
+
2816
+ let col = is_not_null
2817
+ . arg ( )
2818
+ . as_any ( )
2819
+ . downcast_ref :: < Column > ( )
2820
+ . expect ( "Expected Column" ) ;
2821
+
2822
+ assert_eq ! ( col. name( ) , "column" ) ;
2823
+ }
2824
+ }
2772
2825
struct ErrorExtensionPlanner { }
2773
2826
2774
2827
#[ async_trait]
0 commit comments