@@ -46,7 +46,8 @@ use arrow_schema::FieldRef;
4646use datafusion_common:: stats:: Precision ;
4747use datafusion_common:: tree_node:: TreeNodeRecursion ;
4848use datafusion_common:: {
49- Constraint , Constraints , Result , ScalarValue , assert_eq_or_internal_err, not_impl_err,
49+ Constraint , Constraints , Result , ScalarValue , assert_eq_or_internal_err,
50+ internal_err, not_impl_err,
5051} ;
5152use datafusion_execution:: TaskContext ;
5253use datafusion_expr:: { Accumulator , Aggregate } ;
@@ -1052,26 +1053,39 @@ impl AggregateExec {
10521053 self . dynamic_filter . as_ref ( ) . map ( |df| & df. filter )
10531054 }
10541055
1055- /// Replace the dynamic filter expression, recomputing any internal state
1056- /// which may depend on the previous dynamic filter.
1057- ///
1058- /// This is a no-op if the aggregate does not support dynamic filtering.
1059- ///
1060- /// If dynamic filtering is supported, this method returns an error if the filter's
1061- /// children reference invalid columns in the aggregate's input schema.
1056+ /// Replace the dynamic filter expression. This method errors if the aggregate does not
1057+ /// support dynamic filtering or if the filter expression is incompatible with this
1058+ /// [`AggregateExec`].
10621059 pub fn with_dynamic_filter (
10631060 mut self ,
10641061 filter : Arc < DynamicFilterPhysicalExpr > ,
10651062 ) -> Result < Self > {
1066- if let Some ( supported_accumulators_info) = self . supported_accumulators_info ( ) {
1067- for child in filter. children ( ) {
1068- child. data_type ( & self . input_schema ) ?;
1063+ // If there is no dynamic filter state initialized via `try_new`, then
1064+ // we can safely assume that the aggregate does not support dynamic filtering.
1065+ let Some ( dyn_filter) = self . dynamic_filter . as_ref ( ) else {
1066+ return internal_err ! ( "Aggregate does not support dynamic filtering" ) ;
1067+ } ;
1068+
1069+ // Validate that the filter is compatible with the aggregation columns.
1070+ let cols = self . cols_for_dynamic_filter ( & dyn_filter. supported_accumulators_info ) ;
1071+ if cols. len ( ) != filter. children ( ) . len ( ) {
1072+ return internal_err ! (
1073+ "Dynamic filter expression is incompatible with aggregate due to mismatched number of columns"
1074+ ) ;
1075+ }
1076+ for ( col, child) in cols. iter ( ) . zip ( filter. children ( ) ) {
1077+ if !col. eq ( child) {
1078+ return internal_err ! (
1079+ "Dynamic filter expression is incompatible with aggregate due to mismatched column references {col} != {child}"
1080+ ) ;
10691081 }
1070- self . dynamic_filter = Some ( Arc :: new ( AggrDynFilter {
1071- filter,
1072- supported_accumulators_info,
1073- } ) ) ;
10741082 }
1083+
1084+ // Overwrite our filter
1085+ self . dynamic_filter = Some ( Arc :: new ( AggrDynFilter {
1086+ filter,
1087+ supported_accumulators_info : dyn_filter. supported_accumulators_info . clone ( ) ,
1088+ } ) ) ;
10751089 Ok ( self )
10761090 }
10771091
@@ -1257,40 +1271,27 @@ impl AggregateExec {
12571271 /// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field.
12581272 /// - If not supported, `self.dynamic_filter` should be kept `None`
12591273 fn init_dynamic_filter ( & mut self ) {
1260- // Already initialized.
1261- if self . dynamic_filter . is_some ( ) {
1274+ if ( !self . group_by . is_empty ( ) ) || ( self . mode != AggregateMode :: Partial ) {
1275+ debug_assert ! (
1276+ self . dynamic_filter. is_none( ) ,
1277+ "The current operator node does not support dynamic filter"
1278+ ) ;
12621279 return ;
12631280 }
12641281
1265- if let Some ( supported_accumulators_info) = self . supported_accumulators_info ( ) {
1266- // Collect column references for the dynamic filter expression.
1267- let all_cols: Vec < Arc < dyn PhysicalExpr > > = supported_accumulators_info
1268- . iter ( )
1269- . map ( |info| Arc :: clone ( & self . aggr_expr [ info. aggr_index ] . expressions ( ) [ 0 ] ) )
1270- . collect ( ) ;
1271-
1272- self . dynamic_filter = Some ( Arc :: new ( AggrDynFilter {
1273- filter : Arc :: new ( DynamicFilterPhysicalExpr :: new ( all_cols, lit ( true ) ) ) ,
1274- supported_accumulators_info,
1275- } ) ) ;
1276- }
1277- }
1278-
1279- /// Returns the supported accumulator info if this aggregate supports
1280- /// dynamic filtering, or `None` otherwise.
1281- ///
1282- /// Dynamic filtering requires:
1283- /// - `Partial` aggregation mode with no group-by expressions
1284- /// - All aggregate functions are `min` or `max` with a single column arg
1285- fn supported_accumulators_info ( & self ) -> Option < Vec < PerAccumulatorDynFilter > > {
1286- if !self . group_by . is_empty ( ) || !matches ! ( self . mode, AggregateMode :: Partial ) {
1287- return None ;
1282+ // Already initialized.
1283+ if self . dynamic_filter . is_some ( ) {
1284+ return ;
12881285 }
12891286
1290- // Collect supported accumulators.
1291- // It is assumed the order of aggregate expressions are not changed
1292- // from `AggregateExec` to `AggregateStream`.
1287+ // Collect supported accumulators
1288+ // It is assumed the order of aggregate expressions are not changed from `AggregateExec`
1289+ // to `AggregateStream`
12931290 let mut aggr_dyn_filters = Vec :: new ( ) ;
1291+ // All column references in the dynamic filter. They are used when initializing the
1292+ // dynamic filter, and to decide whether this dynamic filter can be pushed through
1293+ // certain nodes during optimization.
1294+ let mut all_cols: Vec < Arc < dyn PhysicalExpr > > = Vec :: new ( ) ;
12941295 for ( i, aggr_expr) in self . aggr_expr . iter ( ) . enumerate ( ) {
12951296 // 1. Only `min` or `max` aggregate function
12961297 let fun_name = aggr_expr. fun ( ) . name ( ) ;
@@ -1301,13 +1302,14 @@ impl AggregateExec {
13011302 } else if fun_name. eq_ignore_ascii_case ( "max" ) {
13021303 DynamicFilterAggregateType :: Max
13031304 } else {
1304- return None ;
1305+ return ;
13051306 } ;
13061307
13071308 // 2. arg should be only 1 column reference
13081309 if let [ arg] = aggr_expr. expressions ( ) . as_slice ( )
13091310 && arg. is :: < Column > ( )
13101311 {
1312+ all_cols. push ( Arc :: clone ( arg) ) ;
13111313 aggr_dyn_filters. push ( PerAccumulatorDynFilter {
13121314 aggr_type,
13131315 aggr_index : i,
@@ -1316,13 +1318,36 @@ impl AggregateExec {
13161318 }
13171319 }
13181320
1319- if aggr_dyn_filters. is_empty ( ) {
1320- None
1321- } else {
1322- Some ( aggr_dyn_filters)
1321+ if !aggr_dyn_filters. is_empty ( ) {
1322+ self . dynamic_filter = Some ( Arc :: new ( AggrDynFilter {
1323+ filter : Arc :: new ( DynamicFilterPhysicalExpr :: new ( all_cols, lit ( true ) ) ) ,
1324+ supported_accumulators_info : aggr_dyn_filters,
1325+ } ) )
13231326 }
13241327 }
13251328
1329+ // Collect column references for the dynamic filter expression from the supported accumulators.
1330+ fn cols_for_dynamic_filter (
1331+ & self ,
1332+ supported_accumulators_info : & [ PerAccumulatorDynFilter ] ,
1333+ ) -> Vec < Arc < dyn PhysicalExpr > > {
1334+ let all_cols: Vec < Arc < dyn PhysicalExpr > > = supported_accumulators_info
1335+ . iter ( )
1336+ . filter_map ( |info| {
1337+ // This should always be true due to how the supported accumulators
1338+ // are constructed. See `init_dynamic_filter` for more details.
1339+ if let [ arg] = & self . aggr_expr [ info. aggr_index ] . expressions ( ) . as_slice ( )
1340+ && arg. is :: < Column > ( )
1341+ {
1342+ return Some ( Arc :: clone ( arg) ) ;
1343+ }
1344+ None
1345+ } )
1346+ . collect ( ) ;
1347+ debug_assert ! ( all_cols. len( ) == supported_accumulators_info. len( ) ) ;
1348+ all_cols
1349+ }
1350+
13261351 /// Calculate scaled byte size based on row count ratio.
13271352 /// Returns `Precision::Absent` if input statistics are insufficient.
13281353 /// Returns `Precision::Inexact` with the scaled value otherwise.
@@ -4817,12 +4842,13 @@ mod tests {
48174842 Ok ( ( ) )
48184843 }
48194844
4845+ /// Test that [`AggregateExec::with_dynamic_filter`] overrides the existing dynamic filter
48204846 #[ test]
48214847 fn test_with_dynamic_filter ( ) -> Result < ( ) > {
48224848 let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int64 , false ) ] ) ) ;
48234849 let child = Arc :: new ( EmptyExec :: new ( Arc :: clone ( & schema) ) ) ;
48244850
4825- // Partial min triggers init_dynamic_filter.
4851+ // Partial min aggregate supports dynamic filtering
48264852 let agg = AggregateExec :: try_new (
48274853 AggregateMode :: Partial ,
48284854 PhysicalGroupBy :: new_single ( vec ! [ ] ) ,
@@ -4836,39 +4862,40 @@ mod tests {
48364862 child,
48374863 Arc :: clone ( & schema) ,
48384864 ) ?;
4839- let original_inner_id = agg
4840- . dynamic_filter ( )
4841- . expect ( "should have dynamic filter after init" )
4842- . expression_id ( )
4843- . expect ( "DynamicFilterPhysicalExpr always has an expression_id" ) ;
48444865
4866+ // Assertion 1: A filter with the same children can override the existing
4867+ // dynamic filter.
48454868 let new_df = Arc :: new ( DynamicFilterPhysicalExpr :: new (
48464869 vec ! [ col( "a" , & schema) ?] ,
4847- lit ( true ) ,
4870+ lit ( false ) ,
48484871 ) ) ;
48494872 let agg = agg. with_dynamic_filter ( Arc :: clone ( & new_df) ) ?;
4850- let restored = agg
4873+
4874+ // The aggregate's filter should now resolve to the new inner expression.
4875+ let swapped = agg
48514876 . dynamic_filter ( )
4852- . expect ( "should still have dynamic filter" ) ;
4853- assert_eq ! (
4854- restored
4855- . expression_id( )
4856- . expect( "DynamicFilterPhysicalExpr always has an expression_id" ) ,
4857- new_df
4858- . expression_id( )
4859- . expect( "DynamicFilterPhysicalExpr always has an expression_id" ) ,
4860- ) ;
4861- assert_ne ! (
4862- restored
4863- . expression_id( )
4864- . expect( "DynamicFilterPhysicalExpr always has an expression_id" ) ,
4865- original_inner_id,
4866- ) ;
4877+ . expect ( "should still have dynamic filter" )
4878+ . current ( ) ?;
4879+ assert_eq ! ( format!( "{swapped}" ) , format!( "{}" , lit( false ) ) ) ;
4880+
4881+ // Assertion 2: A filter that has been through `PhysicalExpr::with_new_children`
4882+ // should still be accepted when the new children are equivalent to the originals.
4883+ let new_df_as_pexpr: Arc < dyn PhysicalExpr > = new_df. clone ( ) ;
4884+ let remapped_pexpr =
4885+ new_df_as_pexpr. with_new_children ( vec ! [ col( "a" , & schema) ?] ) ?;
4886+ let remapped_df = ( remapped_pexpr as Arc < dyn std:: any:: Any + Send + Sync > )
4887+ . downcast :: < DynamicFilterPhysicalExpr > ( )
4888+ . ok ( )
4889+ . expect ( "should be DynamicFilterPhysicalExpr after with_new_children" ) ;
4890+ // Hard to assert this because the filter is identical. No error means
4891+ // the filter was accepted. That's a good enough assertion for now.
4892+ let _agg = agg. with_dynamic_filter ( remapped_df) ?;
48674893 Ok ( ( ) )
48684894 }
48694895
4896+ /// Test that [`AggregateExec::with_dynamic_filter`] errors when the aggregate does not support dynamic filtering
48704897 #[ test]
4871- fn test_with_dynamic_filter_noop_when_unsupported ( ) -> Result < ( ) > {
4898+ fn test_with_dynamic_filter_error_unsupported ( ) -> Result < ( ) > {
48724899 let schema = Arc :: new ( Schema :: new ( vec ! [
48734900 Field :: new( "a" , DataType :: Int64 , false ) ,
48744901 Field :: new( "b" , DataType :: Int64 , false ) ,
@@ -4891,18 +4918,17 @@ mod tests {
48914918 ) ?;
48924919 assert ! ( agg. dynamic_filter( ) . is_none( ) ) ;
48934920
4894- // with_dynamic_filter should be a no-op.
48954921 let df = Arc :: new ( DynamicFilterPhysicalExpr :: new (
48964922 vec ! [ col( "a" , & schema) ?] ,
48974923 lit ( true ) ,
48984924 ) ) ;
4899- let agg = agg. with_dynamic_filter ( df) ?;
4900- assert ! ( agg. dynamic_filter( ) . is_none( ) ) ;
4925+ assert ! ( agg. with_dynamic_filter( df) . is_err( ) ) ;
49014926 Ok ( ( ) )
49024927 }
49034928
4929+ /// Test that [`AggregateExec::with_dynamic_filter`] errors when the filter's columns do not match the aggregate's columns
49044930 #[ test]
4905- fn test_with_dynamic_filter_rejects_invalid_columns ( ) -> Result < ( ) > {
4931+ fn test_with_dynamic_filter_error_column_mismatch ( ) -> Result < ( ) > {
49064932 let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int64 , false ) ] ) ) ;
49074933 let child = Arc :: new ( EmptyExec :: new ( Arc :: clone ( & schema) ) ) ;
49084934
@@ -4920,7 +4946,6 @@ mod tests {
49204946 Arc :: clone ( & schema) ,
49214947 ) ?;
49224948
4923- // Column index 99 is out of bounds for the input schema.
49244949 let df = Arc :: new ( DynamicFilterPhysicalExpr :: new (
49254950 vec ! [ Arc :: new( Column :: new( "bad" , 99 ) ) as _] ,
49264951 lit ( true ) ,
0 commit comments