@@ -55,12 +55,12 @@ use datafusion_common::{
5555use datafusion_execution:: TaskContext ;
5656use datafusion_expr:: Operator ;
5757use datafusion_physical_expr:: equivalence:: ProjectionMapping ;
58- use datafusion_physical_expr:: expressions:: { BinaryExpr , Column , lit} ;
58+ use datafusion_physical_expr:: expressions:: { BinaryExpr , Column , Literal , lit} ;
5959use datafusion_physical_expr:: intervals:: utils:: check_support;
6060use datafusion_physical_expr:: utils:: { collect_columns, reassign_expr_columns} ;
6161use datafusion_physical_expr:: {
62- AcrossPartitions , AnalysisContext , ConstExpr , ExprBoundaries , PhysicalExpr , analyze ,
63- conjunction, split_conjunction,
62+ AcrossPartitions , AnalysisContext , ConstExpr , EquivalenceProperties , ExprBoundaries ,
63+ PhysicalExpr , analyze , conjunction, split_conjunction,
6464} ;
6565
6666use datafusion_physical_expr_common:: physical_expr:: fmt_sql;
@@ -243,6 +243,20 @@ impl FilterExec {
243243 } )
244244 }
245245
246+ /// Returns the `AcrossPartitions` value for `expr` if it is constant:
247+ /// either already known constant in `input_eqs`, or a `Literal`
248+ /// (which is inherently constant across all partitions).
249+ fn expr_constant_or_literal (
250+ expr : & Arc < dyn PhysicalExpr > ,
251+ input_eqs : & EquivalenceProperties ,
252+ ) -> Option < AcrossPartitions > {
253+ input_eqs. is_expr_constant ( expr) . or_else ( || {
254+ expr. as_any ( )
255+ . downcast_ref :: < Literal > ( )
256+ . map ( |l| AcrossPartitions :: Uniform ( Some ( l. value ( ) . clone ( ) ) ) )
257+ } )
258+ }
259+
246260 fn extend_constants (
247261 input : & Arc < dyn ExecutionPlan > ,
248262 predicate : & Arc < dyn PhysicalExpr > ,
@@ -255,18 +269,24 @@ impl FilterExec {
255269 if let Some ( binary) = conjunction. as_any ( ) . downcast_ref :: < BinaryExpr > ( )
256270 && binary. op ( ) == & Operator :: Eq
257271 {
258- // Filter evaluates to single value for all partitions
259- if input_eqs. is_expr_constant ( binary. left ( ) ) . is_some ( ) {
260- let across = input_eqs
261- . is_expr_constant ( binary. right ( ) )
262- . unwrap_or_default ( ) ;
272+ // Check if either side is constant — either already known
273+ // constant from the input equivalence properties, or a literal
274+ // value (which is inherently constant across all partitions).
275+ let left_const = Self :: expr_constant_or_literal ( binary. left ( ) , input_eqs) ;
276+ let right_const =
277+ Self :: expr_constant_or_literal ( binary. right ( ) , input_eqs) ;
278+
279+ if let Some ( left_across) = left_const {
280+ // LEFT is constant, so RIGHT must also be constant.
281+ // Use RIGHT's known across value if available, otherwise
282+ // propagate LEFT's (e.g. Uniform from a literal).
283+ let across = right_const. unwrap_or ( left_across) ;
263284 res_constants
264285 . push ( ConstExpr :: new ( Arc :: clone ( binary. right ( ) ) , across) ) ;
265- } else if input_eqs. is_expr_constant ( binary. right ( ) ) . is_some ( ) {
266- let across = input_eqs
267- . is_expr_constant ( binary. left ( ) )
268- . unwrap_or_default ( ) ;
269- res_constants. push ( ConstExpr :: new ( Arc :: clone ( binary. left ( ) ) , across) ) ;
286+ } else if let Some ( right_across) = right_const {
287+ // RIGHT is constant, so LEFT must also be constant.
288+ res_constants
289+ . push ( ConstExpr :: new ( Arc :: clone ( binary. left ( ) ) , right_across) ) ;
270290 }
271291 }
272292 }
@@ -866,6 +886,19 @@ fn collect_columns_from_predicate_inner(
866886 let predicates = split_conjunction ( predicate) ;
867887 predicates. into_iter ( ) . for_each ( |p| {
868888 if let Some ( binary) = p. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) {
889+ // Only extract pairs where at least one side is a Column reference.
890+ // Pairs like `complex_expr = literal` should not create equivalence
891+ // classes — the literal could appear in many unrelated expressions
892+ // (e.g. sort keys), and normalize_expr's deep traversal would
893+ // replace those occurrences with the complex expression, corrupting
894+ // sort orderings. Constant propagation for such pairs is handled
895+ // separately by `extend_constants`.
896+ let has_direct_column_operand =
897+ binary. left ( ) . as_any ( ) . downcast_ref :: < Column > ( ) . is_some ( )
898+ || binary. right ( ) . as_any ( ) . downcast_ref :: < Column > ( ) . is_some ( ) ;
899+ if !has_direct_column_operand {
900+ return ;
901+ }
869902 match binary. op ( ) {
870903 Operator :: Eq => {
871904 eq_predicate_columns. push ( ( binary. left ( ) , binary. right ( ) ) )
@@ -1700,6 +1733,47 @@ mod tests {
17001733 from output schema (c@0) to input schema (c@2)"
17011734 ) ;
17021735
1736+ Ok ( ( ) )
1737+ }
1738+ /// Regression test for https://github.com/apache/datafusion/issues/20194
1739+ ///
1740+ /// `collect_columns_from_predicate_inner` should only extract equality
1741+ /// pairs where at least one side is a Column. Pairs like
1742+ /// `complex_expr = literal` must not create equivalence classes because
1743+ /// `normalize_expr`'s deep traversal would replace the literal inside
1744+ /// unrelated expressions (e.g. sort keys) with the complex expression.
1745+ #[ test]
1746+ fn test_collect_columns_skips_non_column_pairs ( ) -> Result < ( ) > {
1747+ let schema = test:: aggr_test_schema ( ) ;
1748+
1749+ // Simulate: nvl(c2, 0) = 0 → (c2 IS DISTINCT FROM 0) = 0
1750+ // Neither side is a Column, so this should NOT be extracted.
1751+ let complex_expr: Arc < dyn PhysicalExpr > = binary (
1752+ col ( "c2" , & schema) ?,
1753+ Operator :: IsDistinctFrom ,
1754+ lit ( 0u32 ) ,
1755+ & schema,
1756+ ) ?;
1757+ let predicate: Arc < dyn PhysicalExpr > =
1758+ binary ( complex_expr, Operator :: Eq , lit ( 0u32 ) , & schema) ?;
1759+
1760+ let ( equal_pairs, _) = collect_columns_from_predicate_inner ( & predicate) ;
1761+ assert_eq ! (
1762+ 0 ,
1763+ equal_pairs. len( ) ,
1764+ "Should not extract equality pairs where neither side is a Column"
1765+ ) ;
1766+
1767+ // But col = literal should still be extracted
1768+ let predicate: Arc < dyn PhysicalExpr > =
1769+ binary ( col ( "c2" , & schema) ?, Operator :: Eq , lit ( 0u32 ) , & schema) ?;
1770+ let ( equal_pairs, _) = collect_columns_from_predicate_inner ( & predicate) ;
1771+ assert_eq ! (
1772+ 1 ,
1773+ equal_pairs. len( ) ,
1774+ "Should extract equality pairs where one side is a Column"
1775+ ) ;
1776+
17031777 Ok ( ( ) )
17041778 }
17051779}
0 commit comments