Skip to content

Commit 6d9430d

Browse files
committed
add ut
1 parent e0d3a94 commit 6d9430d

1 file changed

Lines changed: 134 additions & 0 deletions

File tree

datafusion/core/tests/physical_optimizer/enforce_distribution.rs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ use datafusion_common::tree_node::{
4646
use datafusion_datasource::file_groups::FileGroup;
4747
use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
4848
use datafusion_expr::{JoinType, Operator};
49+
use datafusion_functions_aggregate::count::count_udaf;
50+
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
4951
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, binary, lit};
5052
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
5153
use datafusion_physical_expr_common::sort_expr::{
@@ -462,6 +464,71 @@ fn aggregate_exec_with_alias(
462464
)
463465
}
464466

467+
fn partitioned_count_aggregate_exec(
468+
input: Arc<dyn ExecutionPlan>,
469+
group_alias_pairs: Vec<(String, String)>,
470+
count_column: &str,
471+
) -> Arc<dyn ExecutionPlan> {
472+
let input_schema = input.schema();
473+
let group_by_expr = group_alias_pairs
474+
.iter()
475+
.map(|(column, alias)| {
476+
(
477+
col(column, &input_schema).unwrap() as Arc<dyn PhysicalExpr>,
478+
alias.clone(),
479+
)
480+
})
481+
.collect::<Vec<_>>();
482+
let partial_group_by = PhysicalGroupBy::new_single(group_by_expr.clone());
483+
let final_group_by = PhysicalGroupBy::new_single(
484+
group_by_expr
485+
.iter()
486+
.enumerate()
487+
.map(|(idx, (_expr, alias))| {
488+
(
489+
Arc::new(Column::new(alias, idx)) as Arc<dyn PhysicalExpr>,
490+
alias.clone(),
491+
)
492+
})
493+
.collect::<Vec<_>>(),
494+
);
495+
496+
let aggr_expr = vec![Arc::new(
497+
AggregateExprBuilder::new(
498+
count_udaf(),
499+
vec![col(count_column, &input_schema).unwrap()],
500+
)
501+
.schema(Arc::clone(&input_schema))
502+
.alias(format!("COUNT({count_column})"))
503+
.build()
504+
.unwrap(),
505+
)];
506+
507+
let partial = Arc::new(
508+
AggregateExec::try_new(
509+
AggregateMode::Partial,
510+
partial_group_by,
511+
aggr_expr.clone(),
512+
vec![None],
513+
input,
514+
Arc::clone(&input_schema),
515+
)
516+
.unwrap(),
517+
);
518+
519+
Arc::new(
520+
AggregateExec::try_new(
521+
AggregateMode::FinalPartitioned,
522+
final_group_by,
523+
aggr_expr,
524+
vec![None],
525+
Arc::clone(&partial) as _,
526+
partial.schema(),
527+
)
528+
.unwrap(),
529+
)
530+
}
531+
465532
fn hash_join_exec(
466533
left: Arc<dyn ExecutionPlan>,
467534
right: Arc<dyn ExecutionPlan>,
@@ -3322,6 +3389,73 @@ fn preserve_ordering_through_repartition() -> Result<()> {
33223389
Ok(())
33233390
}
33243391

3392+
/// A COUNT aggregate whose grouping keys match the input's sort order
/// should stay fully streaming (`ordering_mode=Sorted`): the inserted
/// hash repartition must preserve the existing ordering rather than
/// force a re-sort.
#[test]
fn preserve_ordering_for_streaming_sorted_aggregate() -> Result<()> {
    let file_schema = schema();
    // Input is sorted on `a`, which is also the (single) grouping key.
    let ordering: LexOrdering = [PhysicalSortExpr {
        expr: col("a", &file_schema)?,
        options: SortOptions::default(),
    }]
    .into();
    let source = parquet_exec_multiple_sorted(vec![ordering]);
    let physical_plan = partitioned_count_aggregate_exec(
        source,
        vec![("a".to_string(), "a".to_string())],
        "b",
    );

    let config = TestConfig::default().with_query_execution_partitions(2);

    let plan_distrib = config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT);
    assert_plan!(plan_distrib, @r"
    AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted
      RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC
        AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted
          DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet
    ");

    // Both optimizer rule orderings must produce the same plan.
    let plan_sort = config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB);
    assert_plan!(plan_distrib, plan_sort);

    Ok(())
}
3423+
3424+
/// When only a prefix of the grouping keys (`a` of `[a, b]`) matches the
/// input's sort order, the aggregate runs as
/// `ordering_mode=PartiallySorted([0])` and the hash repartition must
/// still preserve the existing ordering on `a`.
#[test]
fn preserve_ordering_for_streaming_partially_sorted_aggregate() -> Result<()> {
    let file_schema = schema();
    // Input sorted on `a` only; grouping is on both `a` and `b`.
    let ordering: LexOrdering = [PhysicalSortExpr {
        expr: col("a", &file_schema)?,
        options: SortOptions::default(),
    }]
    .into();
    let source = parquet_exec_multiple_sorted(vec![ordering]);
    let physical_plan = partitioned_count_aggregate_exec(
        source,
        vec![
            ("a".to_string(), "a".to_string()),
            ("b".to_string(), "b".to_string()),
        ],
        "c",
    );

    let config = TestConfig::default().with_query_execution_partitions(2);

    let plan_distrib = config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT);
    assert_plan!(plan_distrib, @r"
    AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b], aggr=[COUNT(c)], ordering_mode=PartiallySorted([0])
      RepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC
        AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b], aggr=[COUNT(c)], ordering_mode=PartiallySorted([0])
          DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet
    ");

    // Both optimizer rule orderings must produce the same plan.
    let plan_sort = config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB);
    assert_plan!(plan_distrib, plan_sort);

    Ok(())
}
3458+
33253459
#[test]
33263460
fn do_not_preserve_ordering_through_repartition() -> Result<()> {
33273461
let schema = schema();

0 commit comments

Comments
 (0)