Skip to content

Commit c3aef20

Browse files
proto: serialize dynamic filters on Sort, Aggregate, HashJoin
Builds on the prior `DynamicFilterPhysicalExpr` proto serialization + dedupe work so plan-node references to a shared dynamic filter survive roundtrip. - Adds `dynamic_filter` to the proto messages for `SortExec`, `AggregateExec`, and `HashJoinExec` and wires it through to/from-proto. - Exposes `dynamic_filter()` / `with_dynamic_filter()` on those plan nodes so the dedupe deserializer can reattach the shared `DynamicFilterPhysicalExpr` after roundtrip. - Extracts `supported_accumulators_info()` on `AggregateExec` and uses it from `init_dynamic_filter` and `with_dynamic_filter`. - Adds `test_hash_join_with_dynamic_filter_roundtrip`, `test_aggregate_with_dynamic_filter_roundtrip`, and `test_sort_topk_with_dynamic_filter_roundtrip` to verify that the plan node and the pushdown-target `ParquetSource` predicate end up pointing at the same `expression_id` after roundtrip. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ba038e9 commit c3aef20

10 files changed

Lines changed: 780 additions & 74 deletions

File tree

datafusion/core/tests/physical_optimizer/filter_pushdown.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2835,7 +2835,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_is_used() {
28352835

28362836
// Verify that a dynamic filter was created
28372837
let dynamic_filter = hash_join
2838-
.dynamic_filter_for_test()
2838+
.dynamic_filter()
28392839
.expect("Dynamic filter should be created");
28402840

28412841
// Verify that is_used() returns the expected value based on probe side support.

datafusion/physical-expr-common/src/physical_expr.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
7272
/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
7373
/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
7474
pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
75-
/// Get the data type of this expression, given the schema of the input
75+
/// Get the data type of this expression, given the schema of the input.
76+
/// Returns an error if the data type cannot be determined, ex. if the
77+
/// schema is missing a required field.
7678
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
7779
Ok(self.return_field(input_schema)?.data_type().to_owned())
7880
}

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 183 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,34 @@ impl AggregateExec {
10471047
&self.input_order_mode
10481048
}
10491049

1050+
/// Returns the dynamic filter expression for this aggregate, if set.
1051+
pub fn dynamic_filter(&self) -> Option<&Arc<DynamicFilterPhysicalExpr>> {
1052+
self.dynamic_filter.as_ref().map(|df| &df.filter)
1053+
}
1054+
1055+
/// Replace the dynamic filter expression, recomputing any internal state
1056+
/// which may depend on the previous dynamic filter.
1057+
///
1058+
/// This is a no-op if the aggregate does not support dynamic filtering.
1059+
///
1060+
/// If dynamic filtering is supported, this method returns an error if the filter's
1061+
/// children reference invalid columns in the aggregate's input schema.
1062+
pub fn with_dynamic_filter(
1063+
mut self,
1064+
filter: Arc<DynamicFilterPhysicalExpr>,
1065+
) -> Result<Self> {
1066+
if let Some(supported_accumulators_info) = self.supported_accumulators_info() {
1067+
for child in filter.children() {
1068+
child.data_type(&self.input_schema)?;
1069+
}
1070+
self.dynamic_filter = Some(Arc::new(AggrDynFilter {
1071+
filter,
1072+
supported_accumulators_info,
1073+
}));
1074+
}
1075+
Ok(self)
1076+
}
1077+
10501078
/// Estimates output statistics for this aggregate node.
10511079
///
10521080
/// For grouped aggregations with known input row count > 1, the output row
@@ -1229,27 +1257,40 @@ impl AggregateExec {
12291257
/// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field.
12301258
/// - If not supported, `self.dynamic_filter` should be kept `None`
12311259
fn init_dynamic_filter(&mut self) {
1232-
if (!self.group_by.is_empty()) || (self.mode != AggregateMode::Partial) {
1233-
debug_assert!(
1234-
self.dynamic_filter.is_none(),
1235-
"The current operator node does not support dynamic filter"
1236-
);
1237-
return;
1238-
}
1239-
12401260
// Already initialized.
12411261
if self.dynamic_filter.is_some() {
12421262
return;
12431263
}
12441264

1245-
// Collect supported accumulators
1246-
// It is assumed the order of aggregate expressions are not changed from `AggregateExec`
1247-
// to `AggregateStream`
1265+
if let Some(supported_accumulators_info) = self.supported_accumulators_info() {
1266+
// Collect column references for the dynamic filter expression.
1267+
let all_cols: Vec<Arc<dyn PhysicalExpr>> = supported_accumulators_info
1268+
.iter()
1269+
.map(|info| Arc::clone(&self.aggr_expr[info.aggr_index].expressions()[0]))
1270+
.collect();
1271+
1272+
self.dynamic_filter = Some(Arc::new(AggrDynFilter {
1273+
filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
1274+
supported_accumulators_info,
1275+
}));
1276+
}
1277+
}
1278+
1279+
/// Returns the supported accumulator info if this aggregate supports
1280+
/// dynamic filtering, or `None` otherwise.
1281+
///
1282+
/// Dynamic filtering requires:
1283+
/// - `Partial` aggregation mode with no group-by expressions
1284+
/// - All aggregate functions are `min` or `max` with a single column arg
1285+
fn supported_accumulators_info(&self) -> Option<Vec<PerAccumulatorDynFilter>> {
1286+
if !self.group_by.is_empty() || !matches!(self.mode, AggregateMode::Partial) {
1287+
return None;
1288+
}
1289+
1290+
// Collect supported accumulators.
1291+
// It is assumed the order of aggregate expressions are not changed
1292+
// from `AggregateExec` to `AggregateStream`.
12481293
let mut aggr_dyn_filters = Vec::new();
1249-
// All column references in the dynamic filter, used when initializing the dynamic
1250-
// filter, and it's used to decide if this dynamic filter is able to get push
1251-
// through certain node during optimization.
1252-
let mut all_cols: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
12531294
for (i, aggr_expr) in self.aggr_expr.iter().enumerate() {
12541295
// 1. Only `min` or `max` aggregate function
12551296
let fun_name = aggr_expr.fun().name();
@@ -1260,14 +1301,13 @@ impl AggregateExec {
12601301
} else if fun_name.eq_ignore_ascii_case("max") {
12611302
DynamicFilterAggregateType::Max
12621303
} else {
1263-
return;
1304+
return None;
12641305
};
12651306

12661307
// 2. arg should be only 1 column reference
12671308
if let [arg] = aggr_expr.expressions().as_slice()
12681309
&& arg.is::<Column>()
12691310
{
1270-
all_cols.push(Arc::clone(arg));
12711311
aggr_dyn_filters.push(PerAccumulatorDynFilter {
12721312
aggr_type,
12731313
aggr_index: i,
@@ -1276,11 +1316,10 @@ impl AggregateExec {
12761316
}
12771317
}
12781318

1279-
if !aggr_dyn_filters.is_empty() {
1280-
self.dynamic_filter = Some(Arc::new(AggrDynFilter {
1281-
filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
1282-
supported_accumulators_info: aggr_dyn_filters,
1283-
}))
1319+
if aggr_dyn_filters.is_empty() {
1320+
None
1321+
} else {
1322+
Some(aggr_dyn_filters)
12841323
}
12851324
}
12861325

@@ -2177,6 +2216,7 @@ mod tests {
21772216
use crate::coalesce_partitions::CoalescePartitionsExec;
21782217
use crate::common;
21792218
use crate::common::collect;
2219+
use crate::empty::EmptyExec;
21802220
use crate::execution_plan::Boundedness;
21812221
use crate::expressions::col;
21822222
use crate::metrics::MetricValue;
@@ -2202,6 +2242,7 @@ mod tests {
22022242
use datafusion_functions_aggregate::count::count_udaf;
22032243
use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
22042244
use datafusion_functions_aggregate::median::median_udaf;
2245+
use datafusion_functions_aggregate::min_max::min_udaf;
22052246
use datafusion_functions_aggregate::sum::sum_udaf;
22062247
use datafusion_physical_expr::Partitioning;
22072248
use datafusion_physical_expr::PhysicalSortExpr;
@@ -3682,13 +3723,10 @@ mod tests {
36823723
// Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
36833724
let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
36843725
Arc::new(
3685-
AggregateExprBuilder::new(
3686-
datafusion_functions_aggregate::min_max::min_udaf(),
3687-
vec![col("b", &schema)?],
3688-
)
3689-
.schema(Arc::clone(&schema))
3690-
.alias("MIN(b)")
3691-
.build()?,
3726+
AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?])
3727+
.schema(Arc::clone(&schema))
3728+
.alias("MIN(b)")
3729+
.build()?,
36923730
),
36933731
Arc::new(
36943732
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
@@ -3827,13 +3865,10 @@ mod tests {
38273865
// Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
38283866
let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
38293867
Arc::new(
3830-
AggregateExprBuilder::new(
3831-
datafusion_functions_aggregate::min_max::min_udaf(),
3832-
vec![col("b", &schema)?],
3833-
)
3834-
.schema(Arc::clone(&schema))
3835-
.alias("MIN(b)")
3836-
.build()?,
3868+
AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?])
3869+
.schema(Arc::clone(&schema))
3870+
.alias("MIN(b)")
3871+
.build()?,
38373872
),
38383873
Arc::new(
38393874
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
@@ -4781,4 +4816,116 @@ mod tests {
47814816

47824817
Ok(())
47834818
}
4819+
4820+
#[test]
4821+
fn test_with_dynamic_filter() -> Result<()> {
4822+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
4823+
let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
4824+
4825+
// Partial min triggers init_dynamic_filter.
4826+
let agg = AggregateExec::try_new(
4827+
AggregateMode::Partial,
4828+
PhysicalGroupBy::new_single(vec![]),
4829+
vec![Arc::new(
4830+
AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
4831+
.schema(Arc::clone(&schema))
4832+
.alias("min_a")
4833+
.build()?,
4834+
)],
4835+
vec![None],
4836+
child,
4837+
Arc::clone(&schema),
4838+
)?;
4839+
let original_inner_id = agg
4840+
.dynamic_filter()
4841+
.expect("should have dynamic filter after init")
4842+
.expression_id()
4843+
.expect("DynamicFilterPhysicalExpr always has an expression_id");
4844+
4845+
let new_df = Arc::new(DynamicFilterPhysicalExpr::new(
4846+
vec![col("a", &schema)?],
4847+
lit(true),
4848+
));
4849+
let agg = agg.with_dynamic_filter(Arc::clone(&new_df))?;
4850+
let restored = agg
4851+
.dynamic_filter()
4852+
.expect("should still have dynamic filter");
4853+
assert_eq!(
4854+
restored
4855+
.expression_id()
4856+
.expect("DynamicFilterPhysicalExpr always has an expression_id"),
4857+
new_df
4858+
.expression_id()
4859+
.expect("DynamicFilterPhysicalExpr always has an expression_id"),
4860+
);
4861+
assert_ne!(
4862+
restored
4863+
.expression_id()
4864+
.expect("DynamicFilterPhysicalExpr always has an expression_id"),
4865+
original_inner_id,
4866+
);
4867+
Ok(())
4868+
}
4869+
4870+
#[test]
4871+
fn test_with_dynamic_filter_noop_when_unsupported() -> Result<()> {
4872+
let schema = Arc::new(Schema::new(vec![
4873+
Field::new("a", DataType::Int64, false),
4874+
Field::new("b", DataType::Int64, false),
4875+
]));
4876+
let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
4877+
4878+
// Final mode with a group-by does not support dynamic filters.
4879+
let agg = AggregateExec::try_new(
4880+
AggregateMode::Final,
4881+
PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]),
4882+
vec![Arc::new(
4883+
AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
4884+
.schema(Arc::clone(&schema))
4885+
.alias("sum_b")
4886+
.build()?,
4887+
)],
4888+
vec![None],
4889+
child,
4890+
Arc::clone(&schema),
4891+
)?;
4892+
assert!(agg.dynamic_filter().is_none());
4893+
4894+
// with_dynamic_filter should be a no-op.
4895+
let df = Arc::new(DynamicFilterPhysicalExpr::new(
4896+
vec![col("a", &schema)?],
4897+
lit(true),
4898+
));
4899+
let agg = agg.with_dynamic_filter(df)?;
4900+
assert!(agg.dynamic_filter().is_none());
4901+
Ok(())
4902+
}
4903+
4904+
#[test]
4905+
fn test_with_dynamic_filter_rejects_invalid_columns() -> Result<()> {
4906+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
4907+
let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
4908+
4909+
let agg = AggregateExec::try_new(
4910+
AggregateMode::Partial,
4911+
PhysicalGroupBy::new_single(vec![]),
4912+
vec![Arc::new(
4913+
AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
4914+
.schema(Arc::clone(&schema))
4915+
.alias("min_a")
4916+
.build()?,
4917+
)],
4918+
vec![None],
4919+
child,
4920+
Arc::clone(&schema),
4921+
)?;
4922+
4923+
// Column index 99 is out of bounds for the input schema.
4924+
let df = Arc::new(DynamicFilterPhysicalExpr::new(
4925+
vec![Arc::new(Column::new("bad", 99)) as _],
4926+
lit(true),
4927+
));
4928+
assert!(agg.with_dynamic_filter(df).is_err());
4929+
Ok(())
4930+
}
47844931
}

0 commit comments

Comments
 (0)