Skip to content

Commit 0f5ea55

Browse files
wip
1 parent c3aef20 commit 0f5ea55

2 files changed

Lines changed: 103 additions & 78 deletions

File tree

  • datafusion/physical-plan/src

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 102 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ use arrow_schema::FieldRef;
4646
use datafusion_common::stats::Precision;
4747
use datafusion_common::tree_node::TreeNodeRecursion;
4848
use datafusion_common::{
49-
Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err, not_impl_err,
49+
Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err,
50+
internal_err, not_impl_err,
5051
};
5152
use datafusion_execution::TaskContext;
5253
use datafusion_expr::{Accumulator, Aggregate};
@@ -1052,26 +1053,39 @@ impl AggregateExec {
10521053
self.dynamic_filter.as_ref().map(|df| &df.filter)
10531054
}
10541055

1055-
/// Replace the dynamic filter expression, recomputing any internal state
1056-
/// which may depend on the previous dynamic filter.
1057-
///
1058-
/// This is a no-op if the aggregate does not support dynamic filtering.
1059-
///
1060-
/// If dynamic filtering is supported, this method returns an error if the filter's
1061-
/// children reference invalid columns in the aggregate's input schema.
1056+
/// Replace the dynamic filter expression. This method errors if the aggregate does not
1057+
/// support dynamic filtering or if the filter expression is incompatible with this
1058+
/// [`AggregateExec`].
10621059
pub fn with_dynamic_filter(
10631060
mut self,
10641061
filter: Arc<DynamicFilterPhysicalExpr>,
10651062
) -> Result<Self> {
1066-
if let Some(supported_accumulators_info) = self.supported_accumulators_info() {
1067-
for child in filter.children() {
1068-
child.data_type(&self.input_schema)?;
1063+
// If there is no dynamic filter state initialized via `try_new`, then
1064+
// we can safely assume that the aggregate does not support dynamic filtering.
1065+
let Some(dyn_filter) = self.dynamic_filter.as_ref() else {
1066+
return internal_err!("Aggregate does not support dynamic filtering");
1067+
};
1068+
1069+
// Validate that the filter is compatible with the aggregation columns.
1070+
let cols = self.cols_for_dynamic_filter(&dyn_filter.supported_accumulators_info);
1071+
if cols.len() != filter.children().len() {
1072+
return internal_err!(
1073+
"Dynamic filter expression is incompatible with aggregate due to mismatched number of columns"
1074+
);
1075+
}
1076+
for (col, child) in cols.iter().zip(filter.children()) {
1077+
if !col.eq(child) {
1078+
return internal_err!(
1079+
"Dynamic filter expression is incompatible with aggregate due to mismatched column references {col} != {child}"
1080+
);
10691081
}
1070-
self.dynamic_filter = Some(Arc::new(AggrDynFilter {
1071-
filter,
1072-
supported_accumulators_info,
1073-
}));
10741082
}
1083+
1084+
// Overwrite our filter
1085+
self.dynamic_filter = Some(Arc::new(AggrDynFilter {
1086+
filter,
1087+
supported_accumulators_info: dyn_filter.supported_accumulators_info.clone(),
1088+
}));
10751089
Ok(self)
10761090
}
10771091

@@ -1257,40 +1271,27 @@ impl AggregateExec {
12571271
/// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field.
12581272
/// - If not supported, `self.dynamic_filter` should be kept `None`
12591273
fn init_dynamic_filter(&mut self) {
1260-
// Already initialized.
1261-
if self.dynamic_filter.is_some() {
1274+
if (!self.group_by.is_empty()) || (self.mode != AggregateMode::Partial) {
1275+
debug_assert!(
1276+
self.dynamic_filter.is_none(),
1277+
"The current operator node does not support dynamic filter"
1278+
);
12621279
return;
12631280
}
12641281

1265-
if let Some(supported_accumulators_info) = self.supported_accumulators_info() {
1266-
// Collect column references for the dynamic filter expression.
1267-
let all_cols: Vec<Arc<dyn PhysicalExpr>> = supported_accumulators_info
1268-
.iter()
1269-
.map(|info| Arc::clone(&self.aggr_expr[info.aggr_index].expressions()[0]))
1270-
.collect();
1271-
1272-
self.dynamic_filter = Some(Arc::new(AggrDynFilter {
1273-
filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
1274-
supported_accumulators_info,
1275-
}));
1276-
}
1277-
}
1278-
1279-
/// Returns the supported accumulator info if this aggregate supports
1280-
/// dynamic filtering, or `None` otherwise.
1281-
///
1282-
/// Dynamic filtering requires:
1283-
/// - `Partial` aggregation mode with no group-by expressions
1284-
/// - All aggregate functions are `min` or `max` with a single column arg
1285-
fn supported_accumulators_info(&self) -> Option<Vec<PerAccumulatorDynFilter>> {
1286-
if !self.group_by.is_empty() || !matches!(self.mode, AggregateMode::Partial) {
1287-
return None;
1282+
// Already initialized.
1283+
if self.dynamic_filter.is_some() {
1284+
return;
12881285
}
12891286

1290-
// Collect supported accumulators.
1291-
// It is assumed the order of aggregate expressions are not changed
1292-
// from `AggregateExec` to `AggregateStream`.
1287+
// Collect supported accumulators
1288+
// It is assumed the order of aggregate expressions are not changed from `AggregateExec`
1289+
// to `AggregateStream`
12931290
let mut aggr_dyn_filters = Vec::new();
1291+
// All column references in the dynamic filter, used when initializing the dynamic
1292+
// filter, and it's used to decide if this dynamic filter is able to get push
1293+
// through certain node during optimization.
1294+
let mut all_cols: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
12941295
for (i, aggr_expr) in self.aggr_expr.iter().enumerate() {
12951296
// 1. Only `min` or `max` aggregate function
12961297
let fun_name = aggr_expr.fun().name();
@@ -1301,13 +1302,14 @@ impl AggregateExec {
13011302
} else if fun_name.eq_ignore_ascii_case("max") {
13021303
DynamicFilterAggregateType::Max
13031304
} else {
1304-
return None;
1305+
return;
13051306
};
13061307

13071308
// 2. arg should be only 1 column reference
13081309
if let [arg] = aggr_expr.expressions().as_slice()
13091310
&& arg.is::<Column>()
13101311
{
1312+
all_cols.push(Arc::clone(arg));
13111313
aggr_dyn_filters.push(PerAccumulatorDynFilter {
13121314
aggr_type,
13131315
aggr_index: i,
@@ -1316,13 +1318,36 @@ impl AggregateExec {
13161318
}
13171319
}
13181320

1319-
if aggr_dyn_filters.is_empty() {
1320-
None
1321-
} else {
1322-
Some(aggr_dyn_filters)
1321+
if !aggr_dyn_filters.is_empty() {
1322+
self.dynamic_filter = Some(Arc::new(AggrDynFilter {
1323+
filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
1324+
supported_accumulators_info: aggr_dyn_filters,
1325+
}))
13231326
}
13241327
}
13251328

1329+
// Collect column references for the dynamic filter expression from the supported accumulators.
1330+
fn cols_for_dynamic_filter(
1331+
&self,
1332+
supported_accumulators_info: &[PerAccumulatorDynFilter],
1333+
) -> Vec<Arc<dyn PhysicalExpr>> {
1334+
let all_cols: Vec<Arc<dyn PhysicalExpr>> = supported_accumulators_info
1335+
.iter()
1336+
.filter_map(|info| {
1337+
// This should always be true due to how the supported accumulators
1338+
// are constructed. See `init_dynamic_filter` for more details.
1339+
if let [arg] = &self.aggr_expr[info.aggr_index].expressions().as_slice()
1340+
&& arg.is::<Column>()
1341+
{
1342+
return Some(Arc::clone(arg));
1343+
}
1344+
None
1345+
})
1346+
.collect();
1347+
debug_assert!(all_cols.len() == supported_accumulators_info.len());
1348+
all_cols
1349+
}
1350+
13261351
/// Calculate scaled byte size based on row count ratio.
13271352
/// Returns `Precision::Absent` if input statistics are insufficient.
13281353
/// Returns `Precision::Inexact` with the scaled value otherwise.
@@ -4817,12 +4842,13 @@ mod tests {
48174842
Ok(())
48184843
}
48194844

4845+
/// Test that [`AggregateExec::with_dynamic_filter`] overrides the existing dynamic filter
48204846
#[test]
48214847
fn test_with_dynamic_filter() -> Result<()> {
48224848
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
48234849
let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
48244850

4825-
// Partial min triggers init_dynamic_filter.
4851+
// Partial min aggregate supports dynamic filtering
48264852
let agg = AggregateExec::try_new(
48274853
AggregateMode::Partial,
48284854
PhysicalGroupBy::new_single(vec![]),
@@ -4836,39 +4862,40 @@ mod tests {
48364862
child,
48374863
Arc::clone(&schema),
48384864
)?;
4839-
let original_inner_id = agg
4840-
.dynamic_filter()
4841-
.expect("should have dynamic filter after init")
4842-
.expression_id()
4843-
.expect("DynamicFilterPhysicalExpr always has an expression_id");
48444865

4866+
// Assertion 1: A filter with the same children can override the existing
4867+
// dynamic filter.
48454868
let new_df = Arc::new(DynamicFilterPhysicalExpr::new(
48464869
vec![col("a", &schema)?],
4847-
lit(true),
4870+
lit(false),
48484871
));
48494872
let agg = agg.with_dynamic_filter(Arc::clone(&new_df))?;
4850-
let restored = agg
4873+
4874+
// The aggregate's filter should now resolve to the new inner expression.
4875+
let swapped = agg
48514876
.dynamic_filter()
4852-
.expect("should still have dynamic filter");
4853-
assert_eq!(
4854-
restored
4855-
.expression_id()
4856-
.expect("DynamicFilterPhysicalExpr always has an expression_id"),
4857-
new_df
4858-
.expression_id()
4859-
.expect("DynamicFilterPhysicalExpr always has an expression_id"),
4860-
);
4861-
assert_ne!(
4862-
restored
4863-
.expression_id()
4864-
.expect("DynamicFilterPhysicalExpr always has an expression_id"),
4865-
original_inner_id,
4866-
);
4877+
.expect("should still have dynamic filter")
4878+
.current()?;
4879+
assert_eq!(format!("{swapped}"), format!("{}", lit(false)));
4880+
4881+
// Assertion 2: A filter that has been through `PhysicalExpr::with_new_children`
4882+
// should still be accepted when the new children are equivalent to the originals.
4883+
let new_df_as_pexpr: Arc<dyn PhysicalExpr> = new_df.clone();
4884+
let remapped_pexpr =
4885+
new_df_as_pexpr.with_new_children(vec![col("a", &schema)?])?;
4886+
let remapped_df = (remapped_pexpr as Arc<dyn std::any::Any + Send + Sync>)
4887+
.downcast::<DynamicFilterPhysicalExpr>()
4888+
.ok()
4889+
.expect("should be DynamicFilterPhysicalExpr after with_new_children");
4890+
// Hard to assert this because the filter is identical. No error means
4891+
// the filter was accepted. That's a good enough assertion for now.
4892+
let _agg = agg.with_dynamic_filter(remapped_df)?;
48674893
Ok(())
48684894
}
48694895

4896+
/// Test that [`AggregateExec::with_dynamic_filter`] errors when the aggregate does not support dynamic filtering
48704897
#[test]
4871-
fn test_with_dynamic_filter_noop_when_unsupported() -> Result<()> {
4898+
fn test_with_dynamic_filter_error_unsupported() -> Result<()> {
48724899
let schema = Arc::new(Schema::new(vec![
48734900
Field::new("a", DataType::Int64, false),
48744901
Field::new("b", DataType::Int64, false),
@@ -4891,18 +4918,17 @@ mod tests {
48914918
)?;
48924919
assert!(agg.dynamic_filter().is_none());
48934920

4894-
// with_dynamic_filter should be a no-op.
48954921
let df = Arc::new(DynamicFilterPhysicalExpr::new(
48964922
vec![col("a", &schema)?],
48974923
lit(true),
48984924
));
4899-
let agg = agg.with_dynamic_filter(df)?;
4900-
assert!(agg.dynamic_filter().is_none());
4925+
assert!(agg.with_dynamic_filter(df).is_err());
49014926
Ok(())
49024927
}
49034928

4929+
/// Test that [`AggregateExec::with_dynamic_filter`] errors when the column is not in the schema
49044930
#[test]
4905-
fn test_with_dynamic_filter_rejects_invalid_columns() -> Result<()> {
4931+
fn test_with_dynamic_filter_error_column_mismatch() -> Result<()> {
49064932
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
49074933
let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
49084934

@@ -4920,7 +4946,6 @@ mod tests {
49204946
Arc::clone(&schema),
49214947
)?;
49224948

4923-
// Column index 99 is out of bounds for the input schema.
49244949
let df = Arc::new(DynamicFilterPhysicalExpr::new(
49254950
vec![Arc::new(Column::new("bad", 99)) as _],
49264951
lit(true),

datafusion/physical-plan/src/joins/hash_join/exec.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ impl HashJoinExec {
910910

911911
/// Set the dynamic filter on this hash join.
912912
///
913-
/// Resets any internal state that depends on any previous dynamic filter.
913+
/// Resets any internal state that depends on any existing dynamic filter.
914914
///
915915
/// Validates that the filter's children reference valid columns in
916916
/// the probe (right) side's schema.

0 commit comments

Comments
 (0)