Skip to content

Commit 291990e

Browse files
wip
1 parent c792700 commit 291990e

2 files changed

Lines changed: 295 additions & 8 deletions

File tree

datafusion/core/tests/physical_optimizer/filter_pushdown.rs

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3145,6 +3145,283 @@ fn test_pushdown_with_empty_group_by() {
31453145
);
31463146
}
31473147

3148+
/// Regression test: a filter on a grouping column is pushed below an
/// `AggregateExec` even when a `ProjectionExec` under the aggregate reorders
/// columns, so that the grouping expressions carry input indices (a@1, b@2)
/// that differ from the aggregate output indices (a@0, b@1).
#[test]
3149+
fn test_pushdown_through_aggregate_with_reordered_input_columns() {
3150+
// Test filter pushdown through aggregate when a ProjectionExec below the
3151+
// aggregate reorders columns, causing input indices to differ from output indices.
3152+
// This reproduces the bug where grouping_columns were built from input schema
3153+
// positions rather than output schema positions.
3154+
let scan = TestScanBuilder::new(schema()).with_support(true).build();
3155+
3156+
// Reorder scan output from (a, b, c) to (c, a, b).
// `reordered_schema` describes the projection output and is the schema the
// aggregate expressions below are resolved against.
3157+
let reordered_schema = Arc::new(Schema::new(vec![
3158+
Field::new("c", DataType::Float64, false),
3159+
Field::new("a", DataType::Utf8, false),
3160+
Field::new("b", DataType::Utf8, false),
3161+
]));
3162+
let projection = Arc::new(
3163+
ProjectionExec::try_new(
3164+
vec![
3165+
(col("c", &schema()).unwrap(), "c".to_string()),
3166+
(col("a", &schema()).unwrap(), "a".to_string()),
3167+
(col("b", &schema()).unwrap(), "b".to_string()),
3168+
],
3169+
scan,
3170+
)
3171+
.unwrap(),
3172+
);
3173+
3174+
// Single aggregate: count over column c, aliased as cnt.
let aggregate_expr = vec![
3175+
AggregateExprBuilder::new(
3176+
count_udaf(),
3177+
vec![col("c", &reordered_schema).unwrap()],
3178+
)
3179+
.schema(reordered_schema.clone())
3180+
.alias("cnt")
3181+
.build()
3182+
.map(Arc::new)
3183+
.unwrap(),
3184+
];
3185+
3186+
// Group by a@1, b@2 (input indices in reordered schema)
3187+
let group_by = PhysicalGroupBy::new_single(vec![
3188+
(col("a", &reordered_schema).unwrap(), "a".to_string()),
3189+
(col("b", &reordered_schema).unwrap(), "b".to_string()),
3190+
]);
3191+
3192+
let aggregate = Arc::new(
3193+
AggregateExec::try_new(
3194+
AggregateMode::Final,
3195+
group_by,
3196+
aggregate_expr,
3197+
vec![None],
3198+
projection,
3199+
reordered_schema,
3200+
)
3201+
.unwrap(),
3202+
);
3203+
3204+
// Filter on b@1 in the OUTPUT schema of the aggregate (a@0, b@1, cnt@2).
3205+
// The grouping expr for b references INPUT index 2, but output index is 1.
3206+
// Without the fix, this filter would be incorrectly blocked.
3207+
let agg_output_schema = aggregate.schema();
3208+
let predicate = col_lit_predicate("b", "bar", &agg_output_schema);
3209+
let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap());
3210+
3211+
// Expected: the FilterExec is gone from the optimized plan and its predicate
// appears on the DataSourceExec, expressed against the scan schema (b@1).
insta::assert_snapshot!(
3212+
OptimizationTest::new(plan, FilterPushdown::new(), true),
3213+
@r"
3214+
OptimizationTest:
3215+
input:
3216+
- FilterExec: b@1 = bar
3217+
- AggregateExec: mode=Final, gby=[a@1 as a, b@2 as b], aggr=[cnt]
3218+
- ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b]
3219+
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
3220+
output:
3221+
Ok:
3222+
- AggregateExec: mode=Final, gby=[a@1 as a, b@2 as b], aggr=[cnt], ordering_mode=PartiallySorted([1])
3223+
- ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b]
3224+
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=b@1 = bar
3225+
"
3226+
);
3227+
}
3228+
3229+
/// Negative counterpart of the reordered-projection test: a filter on the
/// aggregate result column `cnt` (not a grouping column) must stay above the
/// `AggregateExec`, so the optimized plan equals the input plan.
#[test]
3230+
fn test_pushdown_through_aggregate_with_reordered_input_no_pushdown_on_agg_result() {
3231+
// Same reordered setup, but filter is on the aggregate result column (cnt),
3232+
// not a grouping column. The filter should NOT be pushed down.
3233+
let scan = TestScanBuilder::new(schema()).with_support(true).build();
3234+
3235+
// Projection output order is (c, a, b); the scan yields (a, b, c).
let reordered_schema = Arc::new(Schema::new(vec![
3236+
Field::new("c", DataType::Float64, false),
3237+
Field::new("a", DataType::Utf8, false),
3238+
Field::new("b", DataType::Utf8, false),
3239+
]));
3240+
let projection = Arc::new(
3241+
ProjectionExec::try_new(
3242+
vec![
3243+
(col("c", &schema()).unwrap(), "c".to_string()),
3244+
(col("a", &schema()).unwrap(), "a".to_string()),
3245+
(col("b", &schema()).unwrap(), "b".to_string()),
3246+
],
3247+
scan,
3248+
)
3249+
.unwrap(),
3250+
);
3251+
3252+
// Single aggregate: count over column c, aliased as cnt.
let aggregate_expr = vec![
3253+
AggregateExprBuilder::new(
3254+
count_udaf(),
3255+
vec![col("c", &reordered_schema).unwrap()],
3256+
)
3257+
.schema(reordered_schema.clone())
3258+
.alias("cnt")
3259+
.build()
3260+
.map(Arc::new)
3261+
.unwrap(),
3262+
];
3263+
3264+
// Group by a and b, resolved against the reordered schema (a@1, b@2).
let group_by = PhysicalGroupBy::new_single(vec![
3265+
(col("a", &reordered_schema).unwrap(), "a".to_string()),
3266+
(col("b", &reordered_schema).unwrap(), "b".to_string()),
3267+
]);
3268+
3269+
let aggregate = Arc::new(
3270+
AggregateExec::try_new(
3271+
AggregateMode::Final,
3272+
group_by,
3273+
aggregate_expr,
3274+
vec![None],
3275+
projection,
3276+
reordered_schema,
3277+
)
3278+
.unwrap(),
3279+
);
3280+
3281+
// Filter on cnt@2 (aggregate result, not a grouping column).
// The predicate is built by hand as `cnt > 5` against the aggregate output schema.
3282+
let agg_output_schema = aggregate.schema();
3283+
let predicate = Arc::new(BinaryExpr::new(
3284+
Arc::new(Column::new_with_schema("cnt", &agg_output_schema).unwrap()),
3285+
Operator::Gt,
3286+
Arc::new(Literal::new(ScalarValue::Int64(Some(5)))),
3287+
)) as Arc<dyn PhysicalExpr>;
3288+
let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap());
3289+
3290+
// Expected: input and output plans are identical; the FilterExec remains on
// top and no predicate is attached to the DataSourceExec.
insta::assert_snapshot!(
3291+
OptimizationTest::new(plan, FilterPushdown::new(), true),
3292+
@r"
3293+
OptimizationTest:
3294+
input:
3295+
- FilterExec: cnt@2 > 5
3296+
- AggregateExec: mode=Final, gby=[a@1 as a, b@2 as b], aggr=[cnt]
3297+
- ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b]
3298+
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
3299+
output:
3300+
Ok:
3301+
- FilterExec: cnt@2 > 5
3302+
- AggregateExec: mode=Final, gby=[a@1 as a, b@2 as b], aggr=[cnt]
3303+
- ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b]
3304+
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
3305+
"
3306+
);
3307+
}
3308+
3309+
/// GROUPING SETS variant of the reordered-projection setup. A filter on `b`
/// (present in every grouping set) is pushed down to the scan, while a filter
/// on `a` (NULL in the second grouping set) is kept above the `AggregateExec`.
#[test]
3310+
fn test_pushdown_through_aggregate_grouping_sets_with_reordered_input() {
3311+
// Same reordered projection, but with GROUPING SETS.
3312+
// Verifies the GROUPING SETS path also works with mismatched indices.
3313+
let scan = TestScanBuilder::new(schema()).with_support(true).build();
3314+
3315+
// Projection output order is (c, a, b); the scan yields (a, b, c).
let reordered_schema = Arc::new(Schema::new(vec![
3316+
Field::new("c", DataType::Float64, false),
3317+
Field::new("a", DataType::Utf8, false),
3318+
Field::new("b", DataType::Utf8, false),
3319+
]));
3320+
let projection = Arc::new(
3321+
ProjectionExec::try_new(
3322+
vec![
3323+
(col("c", &schema()).unwrap(), "c".to_string()),
3324+
(col("a", &schema()).unwrap(), "a".to_string()),
3325+
(col("b", &schema()).unwrap(), "b".to_string()),
3326+
],
3327+
scan,
3328+
)
3329+
.unwrap(),
3330+
);
3331+
3332+
// Single aggregate: count over column c, aliased as cnt.
let aggregate_expr = vec![
3333+
AggregateExprBuilder::new(
3334+
count_udaf(),
3335+
vec![col("c", &reordered_schema).unwrap()],
3336+
)
3337+
.schema(reordered_schema.clone())
3338+
.alias("cnt")
3339+
.build()
3340+
.map(Arc::new)
3341+
.unwrap(),
3342+
];
3343+
3344+
// GROUPING SETS with (a, b) and (b) — a is missing from the second set.
// Arguments below, in order: the grouping expressions, the NULL literals
// substituted for absent columns, and one boolean mask per grouping set
// (true marks a column that is NULL in that set).
3345+
let group_by = PhysicalGroupBy::new(
3346+
vec![
3347+
(col("a", &reordered_schema).unwrap(), "a".to_string()),
3348+
(col("b", &reordered_schema).unwrap(), "b".to_string()),
3349+
],
3350+
vec![
3351+
(
3352+
Arc::new(Literal::new(ScalarValue::Utf8(None))),
3353+
"a".to_string(),
3354+
),
3355+
(
3356+
Arc::new(Literal::new(ScalarValue::Utf8(None))),
3357+
"b".to_string(),
3358+
),
3359+
],
3360+
vec![
3361+
vec![false, false], // (a, b) - both present
3362+
vec![true, false], // (b) - a is NULL, b present
3363+
],
3364+
true,
3365+
);
3366+
3367+
let aggregate = Arc::new(
3368+
AggregateExec::try_new(
3369+
AggregateMode::Final,
3370+
group_by,
3371+
aggregate_expr,
3372+
vec![None],
3373+
projection,
3374+
reordered_schema,
3375+
)
3376+
.unwrap(),
3377+
);
3378+
3379+
// Both predicates below are resolved against the aggregate output schema.
let agg_output_schema = aggregate.schema();
3380+
3381+
// Filter on b (present in ALL grouping sets) → should be pushed down.
// `aggregate` is cloned here because it is reused for the second case below.
3382+
let predicate = col_lit_predicate("b", "bar", &agg_output_schema);
3383+
let plan = Arc::new(FilterExec::try_new(predicate, aggregate.clone()).unwrap());
3384+
3385+
// Expected: the FilterExec is gone and the predicate appears on the
// DataSourceExec as b@1 = bar.
insta::assert_snapshot!(
3386+
OptimizationTest::new(plan, FilterPushdown::new(), true),
3387+
@r"
3388+
OptimizationTest:
3389+
input:
3390+
- FilterExec: b@1 = bar
3391+
- AggregateExec: mode=Final, gby=[(a@1 as a, b@2 as b), (NULL as a, b@2 as b)], aggr=[cnt]
3392+
- ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b]
3393+
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
3394+
output:
3395+
Ok:
3396+
- AggregateExec: mode=Final, gby=[(a@1 as a, b@2 as b), (NULL as a, b@2 as b)], aggr=[cnt], ordering_mode=PartiallySorted([1])
3397+
- ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b]
3398+
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=b@1 = bar
3399+
"
3400+
);
3401+
3402+
// Filter on a (missing from second grouping set) → should NOT be pushed down
3403+
let predicate = col_lit_predicate("a", "foo", &agg_output_schema);
3404+
let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap());
3405+
3406+
// Expected: input and output plans are identical; the FilterExec stays on top.
insta::assert_snapshot!(
3407+
OptimizationTest::new(plan, FilterPushdown::new(), true),
3408+
@r"
3409+
OptimizationTest:
3410+
input:
3411+
- FilterExec: a@0 = foo
3412+
- AggregateExec: mode=Final, gby=[(a@1 as a, b@2 as b), (NULL as a, b@2 as b)], aggr=[cnt]
3413+
- ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b]
3414+
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
3415+
output:
3416+
Ok:
3417+
- FilterExec: a@0 = foo
3418+
- AggregateExec: mode=Final, gby=[(a@1 as a, b@2 as b), (NULL as a, b@2 as b)], aggr=[cnt]
3419+
- ProjectionExec: expr=[c@2 as c, a@0 as a, b@1 as b]
3420+
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
3421+
"
3422+
);
3423+
}
3424+
31483425
#[test]
31493426
fn test_pushdown_with_computed_grouping_key() {
31503427
// Test filter pushdown with computed grouping expression

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,11 +1473,15 @@ impl ExecutionPlan for AggregateExec {
14731473
// This optimization is NOT safe for filters on aggregated columns (like filtering on
14741474
// the result of SUM or COUNT), as those require computing all groups first.
14751475

1476-
let grouping_columns: HashSet<_> = self
1477-
.group_by
1478-
.expr()
1479-
.iter()
1480-
.flat_map(|(expr, _)| collect_columns(expr))
1476+
// Build grouping columns using OUTPUT indices, not input indices.
1477+
// Parent filters reference the AggregateExec's output schema where grouping
1478+
// columns occupy positions [0..num_groups). The grouping expressions reference
1479+
// INPUT columns which may have different indices (e.g., when an intermediate
1480+
// ProjectionExec reorders columns). We must compare in the same index space.
1481+
let output_schema = self.schema();
1482+
let num_grouping_cols = self.group_by.expr().len();
1483+
let grouping_columns: HashSet<_> = (0..num_grouping_cols)
1484+
.map(|i| Column::new(output_schema.field(i).name(), i))
14811485
.collect();
14821486

14831487
// Analyze each filter separately to determine if it can be pushed down
@@ -1499,12 +1503,18 @@ impl ExecutionPlan for AggregateExec {
14991503

15001504
// For GROUPING SETS, verify this filter's columns appear in all grouping sets
15011505
if self.group_by.groups().len() > 1 {
1506+
// Map filter columns to their grouping expression index via output position
15021507
let filter_column_indices: Vec<usize> = filter_columns
15031508
.iter()
15041509
.filter_map(|filter_col| {
1505-
self.group_by.expr().iter().position(|(expr, _)| {
1506-
collect_columns(expr).contains(filter_col)
1507-
})
1510+
if filter_col.index() < num_grouping_cols
1511+
&& output_schema.field(filter_col.index()).name()
1512+
== filter_col.name()
1513+
{
1514+
Some(filter_col.index())
1515+
} else {
1516+
None
1517+
}
15081518
})
15091519
.collect();
15101520

0 commit comments

Comments
 (0)