Skip to content

Commit 509ad09

Browse files
xudong963 and claude authored
Improvement: keep order-preserving repartitions for streaming aggregates (#21107)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> This PR updates `EnforceDistribution` to keep order-preserving repartition variants when preserving input ordering allows the parent operator to remain incremental/streaming. Previously, order-preserving variants could be removed when `prefer_existing_sort = false` or when there was no explicit ordering requirement, even if dropping the ordering would force a parent operator such as `AggregateExec` to fall back to blocking execution. This change adds a targeted `preserving_order_enables_streaming` check and uses it to avoid replacing `RepartitionExec(..., preserve_order=true)` / `SortPreservingMergeExec` when that preserved ordering is what enables streaming behavior. As a result, the optimizer now prefers keeping order-preserving repartitioning in these cases, and the updated sqllogictests reflect the new physical plans: instead of inserting a `SortExec` above a plain repartition, plans now retain `RepartitionExec(... preserve_order=true)` so sorted or partially sorted aggregates can continue running incrementally. ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. 
Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> No extra sort needed for these cases --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 757ce78 commit 509ad09

7 files changed

Lines changed: 233 additions & 62 deletions

File tree

datafusion/core/tests/physical_optimizer/enforce_distribution.rs

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ use datafusion_common::tree_node::{
4646
use datafusion_datasource::file_groups::FileGroup;
4747
use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
4848
use datafusion_expr::{JoinType, Operator};
49+
use datafusion_functions_aggregate::count::count_udaf;
50+
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
4951
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, binary, lit};
5052
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
5153
use datafusion_physical_expr_common::sort_expr::{
@@ -462,6 +464,71 @@ fn aggregate_exec_with_alias(
462464
)
463465
}
464466

467+
fn partitioned_count_aggregate_exec(
468+
input: Arc<dyn ExecutionPlan>,
469+
group_alias_pairs: Vec<(String, String)>,
470+
count_column: &str,
471+
) -> Arc<dyn ExecutionPlan> {
472+
let input_schema = input.schema();
473+
let group_by_expr = group_alias_pairs
474+
.iter()
475+
.map(|(column, alias)| {
476+
(
477+
col(column, &input_schema).unwrap() as Arc<dyn PhysicalExpr>,
478+
alias.clone(),
479+
)
480+
})
481+
.collect::<Vec<_>>();
482+
let partial_group_by = PhysicalGroupBy::new_single(group_by_expr.clone());
483+
let final_group_by = PhysicalGroupBy::new_single(
484+
group_by_expr
485+
.iter()
486+
.enumerate()
487+
.map(|(idx, (_expr, alias))| {
488+
(
489+
Arc::new(Column::new(alias, idx)) as Arc<dyn PhysicalExpr>,
490+
alias.clone(),
491+
)
492+
})
493+
.collect::<Vec<_>>(),
494+
);
495+
496+
let aggr_expr = vec![Arc::new(
497+
AggregateExprBuilder::new(
498+
count_udaf(),
499+
vec![col(count_column, &input_schema).unwrap()],
500+
)
501+
.schema(Arc::clone(&input_schema))
502+
.alias(format!("COUNT({count_column})"))
503+
.build()
504+
.unwrap(),
505+
)];
506+
507+
let partial = Arc::new(
508+
AggregateExec::try_new(
509+
AggregateMode::Partial,
510+
partial_group_by,
511+
aggr_expr.clone(),
512+
vec![None],
513+
input,
514+
Arc::clone(&input_schema),
515+
)
516+
.unwrap(),
517+
);
518+
519+
Arc::new(
520+
AggregateExec::try_new(
521+
AggregateMode::FinalPartitioned,
522+
final_group_by,
523+
aggr_expr,
524+
vec![None],
525+
Arc::clone(&partial) as _,
526+
partial.schema(),
527+
)
528+
.unwrap(),
529+
)
530+
}
531+
465532
fn hash_join_exec(
466533
left: Arc<dyn ExecutionPlan>,
467534
right: Arc<dyn ExecutionPlan>,
@@ -3322,6 +3389,71 @@ fn preserve_ordering_through_repartition() -> Result<()> {
33223389
Ok(())
33233390
}
33243391

3392+
#[test]
// Regression test for keeping order-preserving repartitions when they enable
// a streaming (Sorted) aggregate: with the input sorted on the group key `a`,
// the optimizer must keep `RepartitionExec(..., preserve_order=true)` instead
// of dropping the ordering and re-sorting above the repartition.
fn preserve_ordering_for_streaming_sorted_aggregate() -> Result<()> {
    let schema = schema();
    // Input ordering: `a ASC` — matches the group-by key, so both aggregate
    // stages can run in `ordering_mode=Sorted` (incremental emission).
    let sort_key: LexOrdering = [PhysicalSortExpr {
        expr: col("a", &schema)?,
        options: SortOptions::default(),
    }]
    .into();
    let input = parquet_exec_multiple_sorted(vec![sort_key]);
    // COUNT(b) GROUP BY a over a 2-partition sorted parquet source.
    let physical_plan = partitioned_count_aggregate_exec(
        input,
        vec![("a".to_string(), "a".to_string())],
        "b",
    );

    let test_config = TestConfig::default().with_query_execution_partitions(2);

    // Expect the hash repartition to preserve order (no SortExec inserted),
    // keeping the final aggregate in Sorted mode.
    // NOTE(review): snapshot indentation reconstructed from the diff — verify
    // against the committed test file.
    let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT);
    assert_plan!(plan_distrib, @r"
    AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted
      RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC
        AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted
          DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet
    ");

    // Both optimizer-rule orderings must converge to the same plan.
    let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB);
    assert_plan!(plan_distrib, plan_sort);

    Ok(())
}
3422+
3423+
#[test]
// Like `preserve_ordering_for_streaming_sorted_aggregate`, but the input is
// sorted only on the first of two group keys, so the aggregate runs in
// `PartiallySorted([0])` mode. The order-preserving repartition must still be
// kept: dropping the ordering would force the aggregate to block.
fn preserve_ordering_for_streaming_partially_sorted_aggregate() -> Result<()> {
    let schema = schema();
    // Input ordering: `a ASC` only — a prefix of the group keys (a, b).
    let sort_key: LexOrdering = [PhysicalSortExpr {
        expr: col("a", &schema)?,
        options: SortOptions::default(),
    }]
    .into();
    let input = parquet_exec_multiple_sorted(vec![sort_key]);
    // COUNT(c) GROUP BY a, b — only the `a` prefix is ordered.
    let physical_plan = partitioned_count_aggregate_exec(
        input,
        vec![
            ("a".to_string(), "a".to_string()),
            ("b".to_string(), "b".to_string()),
        ],
        "c",
    );

    let test_config = TestConfig::default().with_query_execution_partitions(2);

    // Expect preserve_order=true on the hash repartition so the final
    // aggregate stays PartiallySorted (streaming) rather than blocking.
    // NOTE(review): snapshot indentation reconstructed from the diff — verify
    // against the committed test file.
    let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT);
    assert_plan!(plan_distrib, @r"
    AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b], aggr=[COUNT(c)], ordering_mode=PartiallySorted([0])
      RepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC
        AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b], aggr=[COUNT(c)], ordering_mode=PartiallySorted([0])
          DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet
    ");

    // Both optimizer-rule orderings must converge to the same plan.
    let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB);
    assert_plan!(plan_distrib, plan_sort);

    Ok(())
}
3456+
33253457
#[test]
33263458
fn do_not_preserve_ordering_through_repartition() -> Result<()> {
33273459
let schema = schema();

datafusion/physical-optimizer/src/enforce_distribution.rs

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,43 @@ fn add_hash_on_top(
928928
///
929929
/// * `input`: Current node.
930930
///
931+
/// Checks whether preserving the child's ordering enables the parent to
/// run in streaming mode. Compares the parent's pipeline behavior with
/// the ordered child vs. an unordered (coalesced) child. If removing the
/// ordering would cause the parent to switch from streaming to blocking,
/// keeping the order-preserving variant is beneficial.
///
/// Only applicable to single-child operators; returns `Ok(false)` for
/// multi-child operators (e.g. joins) where child substitution semantics are
/// ambiguous.
fn preserving_order_enables_streaming(
    parent: &Arc<dyn ExecutionPlan>,
    ordered_child: &Arc<dyn ExecutionPlan>,
) -> Result<bool> {
    // Only applicable to single-child operators that maintain input order
    // (e.g. AggregateExec in PartiallySorted mode). Operators that don't
    // maintain input order (e.g. SortExec) handle ordering themselves —
    // preserving SPM for them is unnecessary.
    if parent.children().len() != 1 {
        return Ok(false);
    }
    if !parent.maintains_input_order()[0] {
        return Ok(false);
    }
    // Build parent with the ordered child
    let with_ordered =
        Arc::clone(parent).with_new_children(vec![Arc::clone(ordered_child)])?;
    if with_ordered.pipeline_behavior() == EmissionType::Final {
        // Parent is blocking even with ordering — no benefit
        return Ok(false);
    }
    // Build parent with an unordered child via CoalescePartitionsExec.
    // NOTE(review): the coalesce wrapper also collapses the child to a single
    // partition, so the comparison presumably isolates ordering only because
    // pipeline_behavior is partitioning-independent — confirm.
    let unordered_child: Arc<dyn ExecutionPlan> =
        Arc::new(CoalescePartitionsExec::new(Arc::clone(ordered_child)));
    let without_ordered = Arc::clone(parent).with_new_children(vec![unordered_child])?;
    // Streaming is enabled by ordering iff the parent becomes blocking
    // (EmissionType::Final) once the ordering is discarded.
    Ok(without_ordered.pipeline_behavior() == EmissionType::Final)
}
967+
931968
/// # Returns
932969
///
933970
/// Updated node with an execution plan, where the desired single distribution
@@ -1340,6 +1377,12 @@ pub fn ensure_distribution(
13401377
}
13411378
};
13421379

1380+
let streaming_benefit = if child.data {
1381+
preserving_order_enables_streaming(&plan, &child.plan)?
1382+
} else {
1383+
false
1384+
};
1385+
13431386
// There is an ordering requirement of the operator:
13441387
if let Some(required_input_ordering) = required_input_ordering {
13451388
// Either:
@@ -1352,6 +1395,7 @@ pub fn ensure_distribution(
13521395
.ordering_satisfy_requirement(sort_req.clone())?;
13531396

13541397
if (!ordering_satisfied || !order_preserving_variants_desirable)
1398+
&& !streaming_benefit
13551399
&& child.data
13561400
{
13571401
child = replace_order_preserving_variants(child)?;
@@ -1372,6 +1416,11 @@ pub fn ensure_distribution(
13721416
// Stop tracking distribution changing operators
13731417
child.data = false;
13741418
} else {
1419+
let streaming_benefit = if child.data {
1420+
preserving_order_enables_streaming(&plan, &child.plan)?
1421+
} else {
1422+
false
1423+
};
13751424
// no ordering requirement
13761425
match requirement {
13771426
// Operator requires specific distribution.
@@ -1380,7 +1429,7 @@ pub fn ensure_distribution(
13801429
// ordering is pointless. However, if it does maintain
13811430
// input order, we keep order-preserving variants so
13821431
// ordering can flow through to ancestors that need it.
1383-
if !maintains {
1432+
if !maintains && !streaming_benefit {
13841433
child = replace_order_preserving_variants(child)?;
13851434
}
13861435
}

datafusion/sqllogictest/test_files/agg_func_substitute.slt

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,10 @@ logical_plan
4545
physical_plan
4646
01)ProjectionExec: expr=[a@0 as a, nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result]
4747
02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
48-
03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]
49-
04)------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
50-
05)--------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
51-
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true
52-
07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], file_type=csv, has_header=true
48+
03)----RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST
49+
04)------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
50+
05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true
51+
06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], file_type=csv, has_header=true
5352

5453

5554
query TT
@@ -64,11 +63,10 @@ logical_plan
6463
physical_plan
6564
01)ProjectionExec: expr=[a@0 as a, nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result]
6665
02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
67-
03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]
68-
04)------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
69-
05)--------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
70-
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true
71-
07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], file_type=csv, has_header=true
66+
03)----RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST
67+
04)------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
68+
05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true
69+
06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], file_type=csv, has_header=true
7270

7371
query TT
7472
EXPLAIN SELECT a, ARRAY_AGG(c ORDER BY c)[1 + 100] as result
@@ -82,11 +80,10 @@ logical_plan
8280
physical_plan
8381
01)ProjectionExec: expr=[a@0 as a, nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result]
8482
02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
85-
03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]
86-
04)------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
87-
05)--------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
88-
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true
89-
07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], file_type=csv, has_header=true
83+
03)----RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST
84+
04)------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
85+
05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true
86+
06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], file_type=csv, has_header=true
9087

9188
query II
9289
SELECT a, ARRAY_AGG(c ORDER BY c)[1] as result

datafusion/sqllogictest/test_files/group_by.slt

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3971,11 +3971,10 @@ logical_plan
39713971
02)--TableScan: multiple_ordered_table_with_pk projection=[b, c, d]
39723972
physical_plan
39733973
01)AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0])
3974-
02)--SortExec: expr=[c@0 ASC NULLS LAST], preserve_partitioning=[true]
3975-
03)----RepartitionExec: partitioning=Hash([c@0, b@1], 8), input_partitions=8
3976-
04)------AggregateExec: mode=Partial, gby=[c@1 as c, b@0 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0])
3977-
05)--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true
3978-
06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], constraints=[PrimaryKey([3])], file_type=csv, has_header=true
3974+
02)--RepartitionExec: partitioning=Hash([c@0, b@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@0 ASC NULLS LAST
3975+
03)----AggregateExec: mode=Partial, gby=[c@1 as c, b@0 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0])
3976+
04)------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true
3977+
05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], constraints=[PrimaryKey([3])], file_type=csv, has_header=true
39793978

39803979
# drop table multiple_ordered_table_with_pk
39813980
statement ok
@@ -4011,11 +4010,10 @@ logical_plan
40114010
02)--TableScan: multiple_ordered_table_with_pk projection=[b, c, d]
40124011
physical_plan
40134012
01)AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0])
4014-
02)--SortExec: expr=[c@0 ASC NULLS LAST], preserve_partitioning=[true]
4015-
03)----RepartitionExec: partitioning=Hash([c@0, b@1], 8), input_partitions=8
4016-
04)------AggregateExec: mode=Partial, gby=[c@1 as c, b@0 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0])
4017-
05)--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true
4018-
06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], constraints=[PrimaryKey([3])], file_type=csv, has_header=true
4013+
02)--RepartitionExec: partitioning=Hash([c@0, b@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@0 ASC NULLS LAST
4014+
03)----AggregateExec: mode=Partial, gby=[c@1 as c, b@0 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0])
4015+
04)------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true
4016+
05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], constraints=[PrimaryKey([3])], file_type=csv, has_header=true
40194017

40204018
statement ok
40214019
set datafusion.execution.target_partitions = 1;

0 commit comments

Comments
 (0)