Skip to content

Commit 68646e8

Browse files
Dandandanclaude
andcommitted
Use expression tree traversal order instead of sorting for complex expression column collection
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 15bbe4a commit 68646e8

3 files changed

Lines changed: 24 additions & 25 deletions

File tree

datafusion/physical-plan/src/projection.rs

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ use datafusion_execution::TaskContext;
5151
use datafusion_expr::ExpressionPlacement;
5252
use datafusion_physical_expr::equivalence::ProjectionMapping;
5353
use datafusion_physical_expr::projection::Projector;
54-
use datafusion_physical_expr::utils::collect_columns;
5554
use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
5655
use datafusion_physical_expr_common::sort_expr::{
5756
LexOrdering, LexRequirement, PhysicalSortExpr,
@@ -1082,14 +1081,11 @@ fn try_unifying_projections(
10821081
/// Collect all column indices from the given projection expressions.
10831082
fn collect_column_indices(exprs: &[ProjectionExpr]) -> Vec<usize> {
10841083
// Collect column indices in a deterministic order that preserves the
1085-
// projection's column ordering when possible. For simple Column
1086-
// expressions, we use the column index directly (preserving the
1087-
// projection's desired output order). For complex expressions with
1088-
// multiple column references, we sort indices for determinism since
1089-
// collect_columns returns a HashSet with non-deterministic iteration.
1084+
// projection's column ordering. For simple Column expressions, we use
1085+
// the column index directly. For complex expressions, we walk the
1086+
// expression tree to collect column references in traversal order.
10901087
// This allows the embedded projection to match the desired output
1091-
// column order for simple column reorderings, avoiding a residual
1092-
// ProjectionExec.
1088+
// column order, avoiding a residual ProjectionExec.
10931089
let mut seen = std::collections::HashSet::new();
10941090
let mut indices = Vec::new();
10951091
for proj_expr in exprs {
@@ -1099,18 +1095,20 @@ fn collect_column_indices(exprs: &[ProjectionExpr]) -> Vec<usize> {
10991095
indices.push(col.index());
11001096
}
11011097
} else {
1102-
// Complex expression: collect all referenced columns in sorted
1103-
// order for determinism.
1104-
let mut expr_indices: Vec<usize> = collect_columns(&proj_expr.expr)
1105-
.into_iter()
1106-
.map(|c| c.index())
1107-
.collect();
1108-
expr_indices.sort();
1109-
for idx in expr_indices {
1110-
if seen.insert(idx) {
1111-
indices.push(idx);
1112-
}
1113-
}
1098+
// Complex expression: collect all referenced columns in
1099+
// expression tree traversal order (deterministic) to preserve
1100+
// the natural ordering of column references.
1101+
proj_expr
1102+
.expr
1103+
.apply(|expr| {
1104+
if let Some(col) = expr.as_any().downcast_ref::<Column>() {
1105+
if seen.insert(col.index()) {
1106+
indices.push(col.index());
1107+
}
1108+
}
1109+
Ok(TreeNodeRecursion::Continue)
1110+
})
1111+
.expect("closure always returns OK");
11141112
}
11151113
}
11161114
indices
@@ -1226,7 +1224,8 @@ mod tests {
12261224
expr,
12271225
alias: "b-(1+a)".to_string(),
12281226
}]);
1229-
assert_eq!(column_indices, vec![1, 7]);
1227+
// Tree traversal order: b@7 is visited before a@1
1228+
assert_eq!(column_indices, vec![7, 1]);
12301229
Ok(())
12311230
}
12321231

datafusion/sqllogictest/test_files/lateral_join.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,8 +527,8 @@ logical_plan
527527
physical_plan
528528
01)SortPreservingMergeExec: [id@0 ASC NULLS LAST]
529529
02)--SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true]
530-
03)----ProjectionExec: expr=[id@0 as id, CASE WHEN __always_true@2 IS NULL THEN 0 ELSE cnt@1 END as cnt]
531-
04)------HashJoinExec: mode=CollectLeft, join_type=Left, on=[(id@0, t1_id@1)], projection=[id@0, cnt@1, __always_true@3]
530+
03)----ProjectionExec: expr=[id@0 as id, CASE WHEN __always_true@1 IS NULL THEN 0 ELSE cnt@2 END as cnt]
531+
04)------HashJoinExec: mode=CollectLeft, join_type=Left, on=[(id@0, t1_id@1)], projection=[id@0, __always_true@3, cnt@1]
532532
05)--------DataSourceExec: partitions=1, partition_sizes=[1]
533533
06)--------ProjectionExec: expr=[count(Int64(1))@1 as cnt, t1_id@0 as t1_id, true as __always_true]
534534
07)----------AggregateExec: mode=FinalPartitioned, gby=[t1_id@0 as t1_id], aggr=[count(Int64(1))]

datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ physical_plan
8181
04)------AggregateExec: mode=FinalPartitioned, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[sum(profit.amount)]
8282
05)--------RepartitionExec: partitioning=Hash([nation@0, o_year@1], 4), input_partitions=4
8383
06)----------AggregateExec: mode=Partial, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[sum(profit.amount)]
84-
07)------------ProjectionExec: expr=[n_name@0 as nation, date_part(YEAR, o_orderdate@1) as o_year, l_extendedprice@3 * (Some(1),20,0 - l_discount@4) - ps_supplycost@5 * l_quantity@2 as amount]
85-
08)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[n_name@7, o_orderdate@5, l_quantity@0, l_extendedprice@1, l_discount@2, ps_supplycost@4]
84+
07)------------ProjectionExec: expr=[n_name@0 as nation, date_part(YEAR, o_orderdate@1) as o_year, l_extendedprice@2 * (Some(1),20,0 - l_discount@3) - ps_supplycost@4 * l_quantity@5 as amount]
85+
08)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[n_name@7, o_orderdate@5, l_extendedprice@1, l_discount@2, ps_supplycost@4, l_quantity@0]
8686
09)----------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4
8787
10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_orderkey@0, o_orderkey@0)], projection=[l_quantity@1, l_extendedprice@2, l_discount@3, s_nationkey@4, ps_supplycost@5, o_orderdate@7]
8888
11)--------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4

0 commit comments

Comments
 (0)