Skip to content

Commit 9333f74

Browse files
adriangbclaude
andauthored
Make PushDownFilter and CommonSubexprEliminate aware of Expr::placement (#20239)
Teaches PushDownFilter to not push down through `ExpressionPlacement::MoveTowardsLeafNodes` using the same approach already in place for volatile expressions. Split out from #20117. --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ac3a68e commit 9333f74

5 files changed

Lines changed: 240 additions & 10 deletions

File tree

datafusion/optimizer/src/common_subexpr_eliminate.rs

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ use datafusion_expr::expr::{Alias, ScalarFunction};
3434
use datafusion_expr::logical_plan::{
3535
Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
3636
};
37-
use datafusion_expr::{BinaryExpr, Case, Expr, Operator, SortExpr, col};
37+
use datafusion_expr::{
38+
BinaryExpr, Case, Expr, ExpressionPlacement, Operator, SortExpr, col,
39+
};
3840

3941
const CSE_PREFIX: &str = "__common_expr";
4042

@@ -698,6 +700,18 @@ impl CSEController for ExprCSEController<'_> {
698700
}
699701

700702
fn is_ignored(&self, node: &Expr) -> bool {
703+
// MoveTowardsLeafNodes expressions (e.g. get_field) are cheap struct
704+
// field accesses that the ExtractLeafExpressions / PushDownLeafProjections
705+
// rules deliberately duplicate when needed (one copy for a filter
706+
// predicate, another for an output column). CSE deduplicating them
707+
// creates intermediate projections that fight with those rules,
708+
// causing optimizer instability — ExtractLeafExpressions will undo
709+
// the dedup, creating an infinite loop that runs until the iteration
710+
// limit is hit. Skip them.
711+
if node.placement() == ExpressionPlacement::MoveTowardsLeafNodes {
712+
return true;
713+
}
714+
701715
// TODO: remove the next line after `Expr::Wildcard` is removed
702716
#[expect(deprecated)]
703717
let is_normal_minus_aggregates = matches!(
@@ -830,6 +844,7 @@ mod test {
830844
use super::*;
831845
use crate::assert_optimized_plan_eq_snapshot;
832846
use crate::optimizer::OptimizerContext;
847+
use crate::test::udfs::leaf_udf_expr;
833848
use crate::test::*;
834849
use datafusion_expr::test::function_stub::{avg, sum};
835850

@@ -1831,4 +1846,56 @@ mod test {
18311846
panic!("dummy - not implemented")
18321847
}
18331848
}
1849+
1850+
/// Identical MoveTowardsLeafNodes expressions should NOT be deduplicated
1851+
/// by CSE — they are cheap (e.g. struct field access) and the extraction
1852+
/// rules deliberately duplicate them. Deduplicating causes optimizer
1853+
/// instability where one optimizer rule will undo the work of another,
1854+
/// resulting in an infinite optimization loop until the
1855+
/// we hit the max iteration limit and then give up.
1856+
#[test]
1857+
fn test_leaf_expression_not_extracted() -> Result<()> {
1858+
let table_scan = test_table_scan()?;
1859+
1860+
let leaf = leaf_udf_expr(col("a"));
1861+
let plan = LogicalPlanBuilder::from(table_scan)
1862+
.project(vec![leaf.clone().alias("c1"), leaf.alias("c2")])?
1863+
.build()?;
1864+
1865+
// Plan should be unchanged — no __common_expr introduced
1866+
assert_optimized_plan_equal!(
1867+
plan,
1868+
@r"
1869+
Projection: leaf_udf(test.a) AS c1, leaf_udf(test.a) AS c2
1870+
TableScan: test
1871+
"
1872+
)
1873+
}
1874+
1875+
/// When a MoveTowardsLeafNodes expression appears as a sub-expression of
1876+
/// a larger expression that IS duplicated, only the outer expression gets
1877+
/// deduplicated; the leaf sub-expression stays inline.
1878+
#[test]
1879+
fn test_leaf_subexpression_not_extracted() -> Result<()> {
1880+
let table_scan = test_table_scan()?;
1881+
1882+
// leaf_udf(a) + b appears twice — the outer `+` is a common
1883+
// sub-expression, but leaf_udf(a) by itself is MoveTowardsLeafNodes
1884+
// and should not be extracted separately.
1885+
let common = leaf_udf_expr(col("a")) + col("b");
1886+
let plan = LogicalPlanBuilder::from(table_scan)
1887+
.project(vec![common.clone().alias("c1"), common.alias("c2")])?
1888+
.build()?;
1889+
1890+
// The whole `leaf_udf(a) + b` gets deduplicated as __common_expr_1,
1891+
// but leaf_udf(a) alone is NOT pulled out.
1892+
assert_optimized_plan_equal!(
1893+
plan,
1894+
@r"
1895+
Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1896+
Projection: leaf_udf(test.a) + test.b AS __common_expr_1, test.a, test.b, test.c
1897+
TableScan: test
1898+
"
1899+
)
1900+
}
18341901
}

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 83 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ use crate::optimizer::ApplyOrder;
4545
use crate::simplify_expressions::simplify_predicates;
4646
use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
4747
use crate::{OptimizerConfig, OptimizerRule};
48+
use datafusion_expr::ExpressionPlacement;
4849

4950
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
5051
/// they are applied as early as possible.
@@ -1295,10 +1296,13 @@ fn rewrite_projection(
12951296
predicates: Vec<Expr>,
12961297
mut projection: Projection,
12971298
) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1298-
// A projection is filter-commutable if it do not contain volatile predicates or contain volatile
1299-
// predicates that are not used in the filter. However, we should re-writes all predicate expressions.
1300-
// collect projection.
1301-
let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1299+
// Partition projection expressions into non-pushable vs pushable.
1300+
// Non-pushable expressions are volatile (must not be duplicated) or
1301+
// MoveTowardsLeafNodes (cheap expressions like get_field where re-inlining
1302+
// into a filter causes optimizer instability — ExtractLeafExpressions will
1303+
// undo the push-down, creating an infinite loop that runs until the
1304+
// iteration limit is hit).
1305+
let (non_pushable_map, pushable_map): (HashMap<_, _>, HashMap<_, _>) = projection
13021306
.schema
13031307
.iter()
13041308
.zip(projection.expr.iter())
@@ -1308,12 +1312,15 @@ fn rewrite_projection(
13081312

13091313
(qualified_name(qualifier, field.name()), expr)
13101314
})
1311-
.partition(|(_, value)| value.is_volatile());
1315+
.partition(|(_, value)| {
1316+
value.is_volatile()
1317+
|| value.placement() == ExpressionPlacement::MoveTowardsLeafNodes
1318+
});
13121319

13131320
let mut push_predicates = vec![];
13141321
let mut keep_predicates = vec![];
13151322
for expr in predicates {
1316-
if contain(&expr, &volatile_map) {
1323+
if contain(&expr, &non_pushable_map) {
13171324
keep_predicates.push(expr);
13181325
} else {
13191326
push_predicates.push(expr);
@@ -1325,7 +1332,7 @@ fn rewrite_projection(
13251332
// re-write all filters based on this projection
13261333
// E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
13271334
let new_filter = LogicalPlan::Filter(Filter::try_new(
1328-
replace_cols_by_name(expr, &non_volatile_map)?,
1335+
replace_cols_by_name(expr, &pushable_map)?,
13291336
std::mem::take(&mut projection.input),
13301337
)?);
13311338

@@ -1336,7 +1343,10 @@ fn rewrite_projection(
13361343
conjunction(keep_predicates),
13371344
))
13381345
}
1339-
None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1346+
None => Ok((
1347+
Transformed::no(LogicalPlan::Projection(projection)),
1348+
conjunction(keep_predicates),
1349+
)),
13401350
}
13411351
}
13421352

@@ -1446,6 +1456,7 @@ mod tests {
14461456
use crate::assert_optimized_plan_eq_snapshot;
14471457
use crate::optimizer::Optimizer;
14481458
use crate::simplify_expressions::SimplifyExpressions;
1459+
use crate::test::udfs::leaf_udf_expr;
14491460
use crate::test::*;
14501461
use datafusion_expr::test::function_stub::sum;
14511462
use insta::assert_snapshot;
@@ -4221,4 +4232,68 @@ mod tests {
42214232
"
42224233
)
42234234
}
4235+
4236+
/// Test that filters are NOT pushed through MoveTowardsLeafNodes projections.
4237+
/// These are cheap expressions (like get_field) where re-inlining into a filter
4238+
/// has no benefit and causes optimizer instability — ExtractLeafExpressions will
4239+
/// undo the push-down, creating an infinite loop that runs until the iteration
4240+
/// limit is hit.
4241+
#[test]
4242+
fn filter_not_pushed_through_move_towards_leaves_projection() -> Result<()> {
4243+
let table_scan = test_table_scan()?;
4244+
4245+
// Create a projection with a MoveTowardsLeafNodes expression
4246+
let proj = LogicalPlanBuilder::from(table_scan)
4247+
.project(vec![
4248+
leaf_udf_expr(col("a")).alias("val"),
4249+
col("b"),
4250+
col("c"),
4251+
])?
4252+
.build()?;
4253+
4254+
// Put a filter on the MoveTowardsLeafNodes column
4255+
let plan = LogicalPlanBuilder::from(proj)
4256+
.filter(col("val").gt(lit(150i64)))?
4257+
.build()?;
4258+
4259+
// Filter should NOT be pushed through — val maps to a MoveTowardsLeafNodes expr
4260+
assert_optimized_plan_equal!(
4261+
plan,
4262+
@r"
4263+
Filter: val > Int64(150)
4264+
Projection: leaf_udf(test.a) AS val, test.b, test.c
4265+
TableScan: test
4266+
"
4267+
)
4268+
}
4269+
4270+
/// Test mixed predicates: Column predicate pushed, MoveTowardsLeafNodes kept.
4271+
#[test]
4272+
fn filter_mixed_predicates_partial_push() -> Result<()> {
4273+
let table_scan = test_table_scan()?;
4274+
4275+
// Create a projection with both MoveTowardsLeafNodes and Column expressions
4276+
let proj = LogicalPlanBuilder::from(table_scan)
4277+
.project(vec![
4278+
leaf_udf_expr(col("a")).alias("val"),
4279+
col("b"),
4280+
col("c"),
4281+
])?
4282+
.build()?;
4283+
4284+
// Filter with both: val > 150 (MoveTowardsLeafNodes) AND b > 5 (Column)
4285+
let plan = LogicalPlanBuilder::from(proj)
4286+
.filter(col("val").gt(lit(150i64)).and(col("b").gt(lit(5i64))))?
4287+
.build()?;
4288+
4289+
// val > 150 should be kept above, b > 5 should be pushed through
4290+
assert_optimized_plan_equal!(
4291+
plan,
4292+
@r"
4293+
Filter: val > Int64(150)
4294+
Projection: leaf_udf(test.a) AS val, test.b, test.c
4295+
TableScan: test, full_filters=[test.b > Int64(5)]
4296+
"
4297+
)
4298+
}
42244299
}

datafusion/optimizer/src/test/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use datafusion_common::{Result, assert_contains};
2424
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, logical_plan::table_scan};
2525
use std::sync::Arc;
2626

27+
pub mod udfs;
2728
pub mod user_defined;
2829

2930
pub fn test_table_scan_fields() -> Vec<Field> {
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::any::Any;
19+
20+
use arrow::datatypes::DataType;
21+
use datafusion_common::Result;
22+
use datafusion_expr::{
23+
ColumnarValue, Expr, ExpressionPlacement, ScalarFunctionArgs, ScalarUDF,
24+
ScalarUDFImpl, Signature, Volatility,
25+
};
26+
27+
/// A configurable test UDF for optimizer tests.
28+
/// Defaults to `MoveTowardsLeafNodes` placement. Use `with_placement()` to override.
29+
#[derive(Debug, PartialEq, Eq, Hash)]
30+
pub struct PlacementTestUDF {
31+
signature: Signature,
32+
placement: ExpressionPlacement,
33+
}
34+
35+
impl Default for PlacementTestUDF {
36+
fn default() -> Self {
37+
Self::new()
38+
}
39+
}
40+
41+
impl PlacementTestUDF {
42+
pub fn new() -> Self {
43+
Self {
44+
signature: Signature::exact(vec![DataType::UInt32], Volatility::Immutable),
45+
placement: ExpressionPlacement::MoveTowardsLeafNodes,
46+
}
47+
}
48+
49+
pub fn with_placement(mut self, placement: ExpressionPlacement) -> Self {
50+
self.placement = placement;
51+
self
52+
}
53+
}
54+
55+
impl ScalarUDFImpl for PlacementTestUDF {
56+
fn as_any(&self) -> &dyn Any {
57+
self
58+
}
59+
fn name(&self) -> &str {
60+
match self.placement {
61+
ExpressionPlacement::MoveTowardsLeafNodes => "leaf_udf",
62+
ExpressionPlacement::KeepInPlace => "keep_in_place_udf",
63+
ExpressionPlacement::Column => "column_udf",
64+
ExpressionPlacement::Literal => "literal_udf",
65+
}
66+
}
67+
fn signature(&self) -> &Signature {
68+
&self.signature
69+
}
70+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
71+
Ok(DataType::UInt32)
72+
}
73+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
74+
panic!("PlacementTestUDF: not intended for execution")
75+
}
76+
fn placement(&self, _args: &[ExpressionPlacement]) -> ExpressionPlacement {
77+
self.placement
78+
}
79+
}
80+
81+
/// Create a `leaf_udf(arg)` expression with `MoveTowardsLeafNodes` placement.
82+
pub fn leaf_udf_expr(arg: Expr) -> Expr {
83+
let udf = ScalarUDF::new_from_impl(
84+
PlacementTestUDF::new().with_placement(ExpressionPlacement::MoveTowardsLeafNodes),
85+
);
86+
udf.call(vec![arg])
87+
}

datafusion/sqllogictest/test_files/projection_pushdown.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ EXPLAIN SELECT id, s['value'], s['value'] + 10, s['label'] FROM simple_struct OR
856856
----
857857
logical_plan
858858
01)Sort: simple_struct.id ASC NULLS LAST, fetch=3
859-
02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")) AS simple_struct.s[value], get_field(simple_struct.s, Utf8("value")) + Int64(10), get_field(simple_struct.s, Utf8("label"))
859+
02)--Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")), get_field(simple_struct.s, Utf8("value")) + Int64(10), get_field(simple_struct.s, Utf8("label"))
860860
03)----TableScan: simple_struct projection=[id, s]
861861
physical_plan
862862
01)SortExec: TopK(fetch=3), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false]

0 commit comments

Comments
 (0)