@@ -1047,6 +1047,34 @@ impl AggregateExec {
10471047 & self . input_order_mode
10481048 }
10491049
1050+ /// Returns the dynamic filter expression for this aggregate, if set.
1051+ pub fn dynamic_filter ( & self ) -> Option < & Arc < DynamicFilterPhysicalExpr > > {
1052+ self . dynamic_filter . as_ref ( ) . map ( |df| & df. filter )
1053+ }
1054+
1055+ /// Replace the dynamic filter expression, recomputing any internal state
1056+ /// which may depend on the previous dynamic filter.
1057+ ///
1058+ /// This is a no-op if the aggregate does not support dynamic filtering.
1059+ ///
1060+ /// If dynamic filtering is supported, this method returns an error if the filter's
1061+ /// children reference invalid columns in the aggregate's input schema.
1062+ pub fn with_dynamic_filter (
1063+ mut self ,
1064+ filter : Arc < DynamicFilterPhysicalExpr > ,
1065+ ) -> Result < Self > {
1066+ if let Some ( supported_accumulators_info) = self . supported_accumulators_info ( ) {
1067+ for child in filter. children ( ) {
1068+ child. data_type ( & self . input_schema ) ?;
1069+ }
1070+ self . dynamic_filter = Some ( Arc :: new ( AggrDynFilter {
1071+ filter,
1072+ supported_accumulators_info,
1073+ } ) ) ;
1074+ }
1075+ Ok ( self )
1076+ }
1077+
10501078 /// Estimates output statistics for this aggregate node.
10511079 ///
10521080 /// For grouped aggregations with known input row count > 1, the output row
@@ -1229,27 +1257,40 @@ impl AggregateExec {
12291257 /// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field.
12301258 /// - If not supported, `self.dynamic_filter` should be kept `None`
12311259 fn init_dynamic_filter ( & mut self ) {
1232- if ( !self . group_by . is_empty ( ) ) || ( self . mode != AggregateMode :: Partial ) {
1233- debug_assert ! (
1234- self . dynamic_filter. is_none( ) ,
1235- "The current operator node does not support dynamic filter"
1236- ) ;
1237- return ;
1238- }
1239-
12401260 // Already initialized.
12411261 if self . dynamic_filter . is_some ( ) {
12421262 return ;
12431263 }
12441264
1245- // Collect supported accumulators
1246- // It is assumed the order of aggregate expressions are not changed from `AggregateExec`
1247- // to `AggregateStream`
1265+ if let Some ( supported_accumulators_info) = self . supported_accumulators_info ( ) {
1266+ // Collect column references for the dynamic filter expression.
1267+ let all_cols: Vec < Arc < dyn PhysicalExpr > > = supported_accumulators_info
1268+ . iter ( )
1269+ . map ( |info| Arc :: clone ( & self . aggr_expr [ info. aggr_index ] . expressions ( ) [ 0 ] ) )
1270+ . collect ( ) ;
1271+
1272+ self . dynamic_filter = Some ( Arc :: new ( AggrDynFilter {
1273+ filter : Arc :: new ( DynamicFilterPhysicalExpr :: new ( all_cols, lit ( true ) ) ) ,
1274+ supported_accumulators_info,
1275+ } ) ) ;
1276+ }
1277+ }
1278+
1279+ /// Returns the supported accumulator info if this aggregate supports
1280+ /// dynamic filtering, or `None` otherwise.
1281+ ///
1282+ /// Dynamic filtering requires:
1283+ /// - `Partial` aggregation mode with no group-by expressions
1284+ /// - All aggregate functions are `min` or `max` with a single column arg
1285+ fn supported_accumulators_info ( & self ) -> Option < Vec < PerAccumulatorDynFilter > > {
1286+ if !self . group_by . is_empty ( ) || !matches ! ( self . mode, AggregateMode :: Partial ) {
1287+ return None ;
1288+ }
1289+
1290+ // Collect supported accumulators.
1291+ // It is assumed the order of aggregate expressions are not changed
1292+ // from `AggregateExec` to `AggregateStream`.
12481293 let mut aggr_dyn_filters = Vec :: new ( ) ;
1249- // All column references in the dynamic filter, used when initializing the dynamic
1250- // filter, and it's used to decide if this dynamic filter is able to get push
1251- // through certain node during optimization.
1252- let mut all_cols: Vec < Arc < dyn PhysicalExpr > > = Vec :: new ( ) ;
12531294 for ( i, aggr_expr) in self . aggr_expr . iter ( ) . enumerate ( ) {
12541295 // 1. Only `min` or `max` aggregate function
12551296 let fun_name = aggr_expr. fun ( ) . name ( ) ;
@@ -1260,14 +1301,13 @@ impl AggregateExec {
12601301 } else if fun_name. eq_ignore_ascii_case ( "max" ) {
12611302 DynamicFilterAggregateType :: Max
12621303 } else {
1263- return ;
1304+ return None ;
12641305 } ;
12651306
12661307 // 2. arg should be only 1 column reference
12671308 if let [ arg] = aggr_expr. expressions ( ) . as_slice ( )
12681309 && arg. is :: < Column > ( )
12691310 {
1270- all_cols. push ( Arc :: clone ( arg) ) ;
12711311 aggr_dyn_filters. push ( PerAccumulatorDynFilter {
12721312 aggr_type,
12731313 aggr_index : i,
@@ -1276,11 +1316,10 @@ impl AggregateExec {
12761316 }
12771317 }
12781318
1279- if !aggr_dyn_filters. is_empty ( ) {
1280- self . dynamic_filter = Some ( Arc :: new ( AggrDynFilter {
1281- filter : Arc :: new ( DynamicFilterPhysicalExpr :: new ( all_cols, lit ( true ) ) ) ,
1282- supported_accumulators_info : aggr_dyn_filters,
1283- } ) )
1319+ if aggr_dyn_filters. is_empty ( ) {
1320+ None
1321+ } else {
1322+ Some ( aggr_dyn_filters)
12841323 }
12851324 }
12861325
@@ -2177,6 +2216,7 @@ mod tests {
21772216 use crate :: coalesce_partitions:: CoalescePartitionsExec ;
21782217 use crate :: common;
21792218 use crate :: common:: collect;
2219+ use crate :: empty:: EmptyExec ;
21802220 use crate :: execution_plan:: Boundedness ;
21812221 use crate :: expressions:: col;
21822222 use crate :: metrics:: MetricValue ;
@@ -2202,6 +2242,7 @@ mod tests {
22022242 use datafusion_functions_aggregate:: count:: count_udaf;
22032243 use datafusion_functions_aggregate:: first_last:: { first_value_udaf, last_value_udaf} ;
22042244 use datafusion_functions_aggregate:: median:: median_udaf;
2245+ use datafusion_functions_aggregate:: min_max:: min_udaf;
22052246 use datafusion_functions_aggregate:: sum:: sum_udaf;
22062247 use datafusion_physical_expr:: Partitioning ;
22072248 use datafusion_physical_expr:: PhysicalSortExpr ;
@@ -3682,13 +3723,10 @@ mod tests {
36823723 // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
36833724 let aggregates: Vec < Arc < AggregateFunctionExpr > > = vec ! [
36843725 Arc :: new(
3685- AggregateExprBuilder :: new(
3686- datafusion_functions_aggregate:: min_max:: min_udaf( ) ,
3687- vec![ col( "b" , & schema) ?] ,
3688- )
3689- . schema( Arc :: clone( & schema) )
3690- . alias( "MIN(b)" )
3691- . build( ) ?,
3726+ AggregateExprBuilder :: new( min_udaf( ) , vec![ col( "b" , & schema) ?] )
3727+ . schema( Arc :: clone( & schema) )
3728+ . alias( "MIN(b)" )
3729+ . build( ) ?,
36923730 ) ,
36933731 Arc :: new(
36943732 AggregateExprBuilder :: new( avg_udaf( ) , vec![ col( "b" , & schema) ?] )
@@ -3827,13 +3865,10 @@ mod tests {
38273865 // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
38283866 let aggregates: Vec < Arc < AggregateFunctionExpr > > = vec ! [
38293867 Arc :: new(
3830- AggregateExprBuilder :: new(
3831- datafusion_functions_aggregate:: min_max:: min_udaf( ) ,
3832- vec![ col( "b" , & schema) ?] ,
3833- )
3834- . schema( Arc :: clone( & schema) )
3835- . alias( "MIN(b)" )
3836- . build( ) ?,
3868+ AggregateExprBuilder :: new( min_udaf( ) , vec![ col( "b" , & schema) ?] )
3869+ . schema( Arc :: clone( & schema) )
3870+ . alias( "MIN(b)" )
3871+ . build( ) ?,
38373872 ) ,
38383873 Arc :: new(
38393874 AggregateExprBuilder :: new( avg_udaf( ) , vec![ col( "b" , & schema) ?] )
@@ -4781,4 +4816,116 @@ mod tests {
47814816
47824817 Ok ( ( ) )
47834818 }
4819+
/// Replacing the dynamic filter on a supported aggregate must expose the
/// newly supplied filter instead of the one created at construction time.
#[test]
fn test_with_dynamic_filter() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
    let input = Arc::new(EmptyExec::new(Arc::clone(&schema)));

    // Partial min triggers init_dynamic_filter.
    let min_a = AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("min_a")
        .build()?;
    let agg = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::new_single(vec![]),
        vec![Arc::new(min_a)],
        vec![None],
        input,
        Arc::clone(&schema),
    )?;

    // Remember the id of the filter installed by the constructor.
    let original_inner_id = agg
        .dynamic_filter()
        .expect("should have dynamic filter after init")
        .expression_id()
        .expect("DynamicFilterPhysicalExpr always has an expression_id");

    // Install a brand new filter over the same column.
    let new_df = Arc::new(DynamicFilterPhysicalExpr::new(
        vec![col("a", &schema)?],
        lit(true),
    ));
    let agg = agg.with_dynamic_filter(Arc::clone(&new_df))?;

    let restored_id = agg
        .dynamic_filter()
        .expect("should still have dynamic filter")
        .expression_id()
        .expect("DynamicFilterPhysicalExpr always has an expression_id");
    let expected_id = new_df
        .expression_id()
        .expect("DynamicFilterPhysicalExpr always has an expression_id");

    // The aggregate now reports the replacement, not the original.
    assert_eq!(restored_id, expected_id);
    assert_ne!(restored_id, original_inner_id);
    Ok(())
}
4869+
/// `with_dynamic_filter` must leave an aggregate without dynamic filter
/// support unchanged.
#[test]
fn test_with_dynamic_filter_noop_when_unsupported() -> Result<()> {
    let fields = vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
    ];
    let schema = Arc::new(Schema::new(fields));
    let input = Arc::new(EmptyExec::new(Arc::clone(&schema)));

    // Final mode with a group-by does not support dynamic filters.
    let group_by =
        PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
    let sum_b = AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("sum_b")
        .build()?;
    let agg = AggregateExec::try_new(
        AggregateMode::Final,
        group_by,
        vec![Arc::new(sum_b)],
        vec![None],
        input,
        Arc::clone(&schema),
    )?;
    assert!(agg.dynamic_filter().is_none());

    // with_dynamic_filter should be a no-op.
    let candidate = Arc::new(DynamicFilterPhysicalExpr::new(
        vec![col("a", &schema)?],
        lit(true),
    ));
    let agg = agg.with_dynamic_filter(candidate)?;
    assert!(agg.dynamic_filter().is_none());
    Ok(())
}
4903+
/// A replacement filter whose children reference columns missing from the
/// aggregate's input schema must be rejected with an error.
#[test]
fn test_with_dynamic_filter_rejects_invalid_columns() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
    let input = Arc::new(EmptyExec::new(Arc::clone(&schema)));

    let min_a = AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("min_a")
        .build()?;
    let agg = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::new_single(vec![]),
        vec![Arc::new(min_a)],
        vec![None],
        input,
        Arc::clone(&schema),
    )?;

    // Column index 99 is out of bounds for the input schema.
    let bad_filter = Arc::new(DynamicFilterPhysicalExpr::new(
        vec![Arc::new(Column::new("bad", 99)) as _],
        lit(true),
    ));
    let outcome = agg.with_dynamic_filter(bad_filter);
    assert!(outcome.is_err());
    Ok(())
}
47844931}
0 commit comments