@@ -1036,7 +1036,19 @@ impl GroupedHashAggregateStream {
10361036 self . group_values . len ( )
10371037 } ;
10381038
1039- if let Some ( batch) = self . emit ( EmitTo :: First ( n) , false ) ? {
1039+ // Clamp to the sort boundary when using partial group ordering,
1040+ // otherwise remove_groups panics (#20445).
1041+ let n = match & self . group_ordering {
1042+ GroupOrdering :: None => n,
1043+ _ => match self . group_ordering . emit_to ( ) {
1044+ Some ( EmitTo :: First ( max) ) => n. min ( max) ,
1045+ _ => 0 ,
1046+ } ,
1047+ } ;
1048+
1049+ if n > 0
1050+ && let Some ( batch) = self . emit ( EmitTo :: First ( n) , false ) ?
1051+ {
10401052 Ok ( Some ( ExecutionState :: ProducingOutput ( batch) ) )
10411053 } else {
10421054 Err ( oom)
@@ -1291,6 +1303,7 @@ impl GroupedHashAggregateStream {
12911303#[ cfg( test) ]
12921304mod tests {
12931305 use super :: * ;
1306+ use crate :: InputOrderMode ;
12941307 use crate :: execution_plan:: ExecutionPlan ;
12951308 use crate :: test:: TestMemoryExec ;
12961309 use arrow:: array:: { Int32Array , Int64Array } ;
@@ -1553,4 +1566,88 @@ mod tests {
15531566
15541567 Ok ( ( ) )
15551568 }
1569+
1570+ #[ tokio:: test]
1571+ async fn test_emit_early_with_partially_sorted ( ) -> Result < ( ) > {
1572+ // Reproducer for #20445: EmitEarly with PartiallySorted panics in
1573+ // remove_groups because it emits more groups than the sort boundary.
1574+ let schema = Arc :: new ( Schema :: new ( vec ! [
1575+ Field :: new( "sort_col" , DataType :: Int32 , false ) ,
1576+ Field :: new( "group_col" , DataType :: Int32 , false ) ,
1577+ Field :: new( "value_col" , DataType :: Int64 , false ) ,
1578+ ] ) ) ;
1579+
1580+ // All rows share sort_col=1 (no sort boundary), with unique group_col
1581+ // values to create many groups and trigger memory pressure.
1582+ let n = 256 ;
1583+ let batch = RecordBatch :: try_new (
1584+ Arc :: clone ( & schema) ,
1585+ vec ! [
1586+ Arc :: new( Int32Array :: from( vec![ 1 ; n] ) ) ,
1587+ Arc :: new( Int32Array :: from( ( 0 ..n as i32 ) . collect:: <Vec <_>>( ) ) ) ,
1588+ Arc :: new( Int64Array :: from( vec![ 1 ; n] ) ) ,
1589+ ] ,
1590+ ) ?;
1591+
1592+ let runtime = RuntimeEnvBuilder :: default ( )
1593+ . with_memory_limit ( 4096 , 1.0 )
1594+ . build_arc ( ) ?;
1595+ let mut task_ctx = TaskContext :: default ( ) . with_runtime ( runtime) ;
1596+ let mut cfg = task_ctx. session_config ( ) . clone ( ) ;
1597+ cfg = cfg. set (
1598+ "datafusion.execution.batch_size" ,
1599+ & datafusion_common:: ScalarValue :: UInt64 ( Some ( 128 ) ) ,
1600+ ) ;
1601+ cfg = cfg. set (
1602+ "datafusion.execution.skip_partial_aggregation_probe_rows_threshold" ,
1603+ & datafusion_common:: ScalarValue :: UInt64 ( Some ( u64:: MAX ) ) ,
1604+ ) ;
1605+ task_ctx = task_ctx. with_session_config ( cfg) ;
1606+ let task_ctx = Arc :: new ( task_ctx) ;
1607+
1608+ let ordering = LexOrdering :: new ( vec ! [ PhysicalSortExpr :: new_default( Arc :: new(
1609+ Column :: new( "sort_col" , 0 ) ,
1610+ )
1611+ as _) ] )
1612+ . unwrap ( ) ;
1613+ let exec = TestMemoryExec :: try_new ( & [ vec ! [ batch] ] , Arc :: clone ( & schema) , None ) ?
1614+ . try_with_sort_information ( vec ! [ ordering] ) ?;
1615+ let exec = Arc :: new ( TestMemoryExec :: update_cache ( & Arc :: new ( exec) ) ) ;
1616+
1617+ // GROUP BY sort_col, group_col with input sorted on sort_col
1618+ // gives PartiallySorted([0])
1619+ let aggregate_exec = AggregateExec :: try_new (
1620+ AggregateMode :: Partial ,
1621+ PhysicalGroupBy :: new_single ( vec ! [
1622+ ( col( "sort_col" , & schema) ?, "sort_col" . to_string( ) ) ,
1623+ ( col( "group_col" , & schema) ?, "group_col" . to_string( ) ) ,
1624+ ] ) ,
1625+ vec ! [ Arc :: new(
1626+ AggregateExprBuilder :: new( count_udaf( ) , vec![ col( "value_col" , & schema) ?] )
1627+ . schema( Arc :: clone( & schema) )
1628+ . alias( "count_value" )
1629+ . build( ) ?,
1630+ ) ] ,
1631+ vec ! [ None ] ,
1632+ exec,
1633+ Arc :: clone ( & schema) ,
1634+ ) ?;
1635+ assert ! ( matches!(
1636+ aggregate_exec. input_order_mode( ) ,
1637+ InputOrderMode :: PartiallySorted ( _)
1638+ ) ) ;
1639+
1640+ // Must not panic with "assertion failed: *current_sort >= n"
1641+ let mut stream = GroupedHashAggregateStream :: new ( & aggregate_exec, & task_ctx, 0 ) ?;
1642+ while let Some ( result) = stream. next ( ) . await {
1643+ if let Err ( e) = result {
1644+ if e. to_string ( ) . contains ( "Resources exhausted" ) {
1645+ break ;
1646+ }
1647+ return Err ( e) ;
1648+ }
1649+ }
1650+
1651+ Ok ( ( ) )
1652+ }
15561653}
0 commit comments