Skip to content

Commit 11501ff

Browse files
committed
add ut
1 parent 3f07243 commit 11501ff

1 file changed

Lines changed: 134 additions & 0 deletions

File tree

datafusion/core/tests/physical_optimizer/enforce_distribution.rs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ use datafusion_common::tree_node::{
4646
use datafusion_datasource::file_groups::FileGroup;
4747
use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
4848
use datafusion_expr::{JoinType, Operator};
49+
use datafusion_functions_aggregate::count::count_udaf;
50+
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
4951
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, binary, lit};
5052
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
5153
use datafusion_physical_expr_common::sort_expr::{
@@ -361,6 +363,71 @@ fn aggregate_exec_with_alias(
361363
)
362364
}
363365

366+
fn partitioned_count_aggregate_exec(
367+
input: Arc<dyn ExecutionPlan>,
368+
group_alias_pairs: Vec<(String, String)>,
369+
count_column: &str,
370+
) -> Arc<dyn ExecutionPlan> {
371+
let input_schema = input.schema();
372+
let group_by_expr = group_alias_pairs
373+
.iter()
374+
.map(|(column, alias)| {
375+
(
376+
col(column, &input_schema).unwrap() as Arc<dyn PhysicalExpr>,
377+
alias.clone(),
378+
)
379+
})
380+
.collect::<Vec<_>>();
381+
let partial_group_by = PhysicalGroupBy::new_single(group_by_expr.clone());
382+
let final_group_by = PhysicalGroupBy::new_single(
383+
group_by_expr
384+
.iter()
385+
.enumerate()
386+
.map(|(idx, (_expr, alias))| {
387+
(
388+
Arc::new(Column::new(alias, idx)) as Arc<dyn PhysicalExpr>,
389+
alias.clone(),
390+
)
391+
})
392+
.collect::<Vec<_>>(),
393+
);
394+
395+
let aggr_expr = vec![Arc::new(
396+
AggregateExprBuilder::new(
397+
count_udaf(),
398+
vec![col(count_column, &input_schema).unwrap()],
399+
)
400+
.schema(Arc::clone(&input_schema))
401+
.alias(format!("COUNT({count_column})"))
402+
.build()
403+
.unwrap(),
404+
)];
405+
406+
let partial = Arc::new(
407+
AggregateExec::try_new(
408+
AggregateMode::Partial,
409+
partial_group_by,
410+
aggr_expr.clone(),
411+
vec![None],
412+
input,
413+
Arc::clone(&input_schema),
414+
)
415+
.unwrap(),
416+
);
417+
418+
Arc::new(
419+
AggregateExec::try_new(
420+
AggregateMode::FinalPartitioned,
421+
final_group_by,
422+
aggr_expr,
423+
vec![None],
424+
Arc::clone(&partial) as _,
425+
partial.schema(),
426+
)
427+
.unwrap(),
428+
)
429+
}
430+
364431
fn hash_join_exec(
365432
left: Arc<dyn ExecutionPlan>,
366433
right: Arc<dyn ExecutionPlan>,
@@ -3221,6 +3288,73 @@ fn preserve_ordering_through_repartition() -> Result<()> {
32213288
Ok(())
32223289
}
32233290

3291+
/// Fully-sorted streaming aggregate: when the group-by key matches the
/// source ordering (`a ASC`), the repartition inserted between the partial
/// and final aggregates must preserve that order so both stages keep
/// `ordering_mode=Sorted`.
#[test]
fn preserve_ordering_for_streaming_sorted_aggregate() -> Result<()> {
    let schema = schema();
    let ordering: LexOrdering = [PhysicalSortExpr {
        expr: col("a", &schema)?,
        options: SortOptions::default(),
    }]
    .into();
    let source = parquet_exec_multiple_sorted(vec![ordering]);
    // COUNT(b) grouped by `a` — the group key equals the source sort key.
    let plan = partitioned_count_aggregate_exec(
        source,
        vec![("a".to_string(), "a".to_string())],
        "b",
    );

    let config = TestConfig::default().with_query_execution_partitions(2);

    let distributed = config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT);
    assert_plan!(distributed, @r"
    AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted
      RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC
        AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted
          DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet
    ");

    // Running the optimizer rules in the opposite order must yield the same plan.
    let sorted_first = config.to_plan(plan, &SORT_DISTRIB_DISTRIB);
    assert_plan!(distributed, sorted_first);

    Ok(())
}
3322+
3323+
/// Partially-sorted streaming aggregate: the group-by keys `(a, b)` share
/// only the `a ASC` prefix with the source ordering, so both aggregate
/// stages run with `ordering_mode=PartiallySorted([0])`; the inserted
/// repartition must still preserve the `a ASC` order.
#[test]
fn preserve_ordering_for_streaming_partially_sorted_aggregate() -> Result<()> {
    let schema = schema();
    let ordering: LexOrdering = [PhysicalSortExpr {
        expr: col("a", &schema)?,
        options: SortOptions::default(),
    }]
    .into();
    let source = parquet_exec_multiple_sorted(vec![ordering]);
    // COUNT(c) grouped by (a, b): only `a` is covered by the source ordering.
    let plan = partitioned_count_aggregate_exec(
        source,
        vec![
            ("a".to_string(), "a".to_string()),
            ("b".to_string(), "b".to_string()),
        ],
        "c",
    );

    let config = TestConfig::default().with_query_execution_partitions(2);

    let distributed = config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT);
    assert_plan!(distributed, @r"
    AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b], aggr=[COUNT(c)], ordering_mode=PartiallySorted([0])
      RepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@0 ASC
        AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b], aggr=[COUNT(c)], ordering_mode=PartiallySorted([0])
          DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet
    ");

    // Running the optimizer rules in the opposite order must yield the same plan.
    let sorted_first = config.to_plan(plan, &SORT_DISTRIB_DISTRIB);
    assert_plan!(distributed, sorted_first);

    Ok(())
}
3357+
32243358
#[test]
32253359
fn do_not_preserve_ordering_through_repartition() -> Result<()> {
32263360
let schema = schema();

0 commit comments

Comments
 (0)