@@ -28,11 +28,12 @@ use crate::physical_optimizer::test_utils::{
2828use arrow:: datatypes:: DataType ;
2929use arrow:: { compute:: SortOptions , util:: pretty:: pretty_format_batches} ;
3030use datafusion:: prelude:: SessionContext ;
31- use datafusion_common:: Result ;
31+ use datafusion_common:: { Result , config :: ConfigOptions } ;
3232use datafusion_execution:: config:: SessionConfig ;
3333use datafusion_expr:: Operator ;
3434use datafusion_physical_expr:: expressions:: { self , cast, col} ;
3535use datafusion_physical_expr_common:: sort_expr:: PhysicalSortExpr ;
36+ use datafusion_physical_optimizer:: PhysicalOptimizerRule ;
3637use datafusion_physical_plan:: {
3738 ExecutionPlan ,
3839 aggregates:: { AggregateExec , AggregateMode } ,
@@ -332,7 +333,7 @@ fn test_has_aggregate_expression() -> Result<()> {
332333 let schema = source. schema ( ) ;
333334 let agg = TestAggregate :: new_count_star ( ) ;
334335
335- // `SELECT <aggregate with no expressions> FROM DataSourceExec LIMIT 10;`, Single AggregateExec
336+ // `SELECT a, COUNT(*) FROM DataSourceExec GROUP BY a LIMIT 10;`, Single AggregateExec
336337 let single_agg = AggregateExec :: try_new (
337338 AggregateMode :: Single ,
338339 build_group_by ( & schema, vec ! [ "a" . to_string( ) ] ) ,
@@ -345,21 +346,82 @@ fn test_has_aggregate_expression() -> Result<()> {
345346 Arc :: new ( single_agg) ,
346347 10 , // fetch
347348 ) ;
348- // expected not to push the limit to the AggregateExec
349+ // expected to push the limit to the AggregateExec
349350 let plan: Arc < dyn ExecutionPlan > = Arc :: new ( limit_exec) ;
350351 let formatted = get_optimized_plan ( & plan) ?;
351352 let actual = formatted. trim ( ) ;
352353 assert_snapshot ! (
353354 actual,
354355 @r"
355356 LocalLimitExec: fetch=10
356- AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]
357+ AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)], lim=[10]
357358 DataSourceExec: partitions=1, partition_sizes=[1]
358359 "
359360 ) ;
360361 Ok ( ( ) )
361362}
362363
364+ #[ tokio:: test]
365+ async fn test_partial_final_with_aggregate_expression ( ) -> Result < ( ) > {
366+ let source = mock_data ( ) ?;
367+ let schema = source. schema ( ) ;
368+ let agg = TestAggregate :: new_count_star ( ) ;
369+
370+ // `SELECT a, COUNT(*) FROM DataSourceExec GROUP BY a LIMIT 4;`,
371+ // Partial/Final AggregateExec. Both stages can keep the same deterministic
372+ // top-k group keys.
373+ let partial_agg = AggregateExec :: try_new (
374+ AggregateMode :: Partial ,
375+ build_group_by ( & schema. clone ( ) , vec ! [ "a" . to_string( ) ] ) ,
376+ vec ! [ Arc :: new( agg. count_expr( & schema) ) ] , /* aggr_expr */
377+ vec ! [ None ] , /* filter_expr */
378+ source, /* input */
379+ schema. clone ( ) , /* input_schema */
380+ ) ?;
381+ let final_agg = AggregateExec :: try_new (
382+ AggregateMode :: Final ,
383+ build_group_by ( & schema. clone ( ) , vec ! [ "a" . to_string( ) ] ) ,
384+ vec ! [ Arc :: new( agg. count_expr( & schema) ) ] , /* aggr_expr */
385+ vec ! [ None ] , /* filter_expr */
386+ Arc :: new ( partial_agg) , /* input */
387+ schema. clone ( ) , /* input_schema */
388+ ) ?;
389+ let limit_exec = LocalLimitExec :: new (
390+ Arc :: new ( final_agg) ,
391+ 4 , // fetch
392+ ) ;
393+ let plan: Arc < dyn ExecutionPlan > = Arc :: new ( limit_exec) ;
394+ let formatted = get_optimized_plan ( & plan) ?;
395+ let actual = formatted. trim ( ) ;
396+ assert_snapshot ! (
397+ actual,
398+ @r"
399+ LocalLimitExec: fetch=4
400+ AggregateExec: mode=Final, gby=[a@0 as a], aggr=[COUNT(*)], lim=[4]
401+ AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[COUNT(*)], lim=[4]
402+ DataSourceExec: partitions=1, partition_sizes=[1]
403+ "
404+ ) ;
405+ let optimized =
406+ datafusion_physical_optimizer:: limited_distinct_aggregation:: LimitedDistinctAggregation :: new ( )
407+ . optimize ( Arc :: clone ( & plan) , & ConfigOptions :: new ( ) ) ?;
408+ let expected = run_plan_and_format ( optimized) . await ?;
409+ assert_snapshot ! (
410+ expected,
411+ @r"
412+ +---+----------+
413+ | a | COUNT(*) |
414+ +---+----------+
415+ | | 1 |
416+ | 1 | 2 |
417+ | 2 | 1 |
418+ | 4 | 1 |
419+ +---+----------+
420+ "
421+ ) ;
422+ Ok ( ( ) )
423+ }
424+
363425#[ test]
364426fn test_has_filter ( ) -> Result < ( ) > {
365427 let source = mock_data ( ) ?;
0 commit comments