Skip to content

Commit 2da82cd

Browse files
committed
Consolidate plan-wrapper traversal and optimizations
Combine plan-wrapper traversal and cross-join shape detection. Shorten join-column replacement scan and share authoritative null-result decoding. Remove unused helpers and reorganize strict-null operator list behind a classifier helper. Public interfaces remain unchanged.
1 parent e74f5da commit 2da82cd

4 files changed

Lines changed: 146 additions & 131 deletions

File tree

datafusion/core/benches/sql_planner_extended.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,12 @@ fn build_non_case_left_join_df_with_push_down_filter(
357357
fn find_filter_predicates(plan: &LogicalPlan) -> Vec<datafusion_expr::Expr> {
358358
match plan {
359359
LogicalPlan::Filter(filter) => split_conjunction_owned(filter.predicate.clone()),
360-
LogicalPlan::Projection(projection) => find_filter_predicates(projection.input.as_ref()),
361-
other => panic!("expected benchmark query plan to contain a Filter, found {other:?}"),
360+
LogicalPlan::Projection(projection) => {
361+
find_filter_predicates(projection.input.as_ref())
362+
}
363+
other => {
364+
panic!("expected benchmark query plan to contain a Filter, found {other:?}")
365+
}
362366
}
363367
}
364368

@@ -375,7 +379,8 @@ fn assert_case_heavy_left_join_inference_candidates(
375379
for predicate in predicates {
376380
let column_refs = predicate.column_refs();
377381
assert!(
378-
column_refs.contains(&&left_join_key) || column_refs.contains(&&right_join_key),
382+
column_refs.contains(&&left_join_key)
383+
|| column_refs.contains(&&right_join_key),
379384
"benchmark predicate should reference a join key: {predicate}"
380385
);
381386
assert!(

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -285,33 +285,39 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
285285
Ok(is_evaluate)
286286
}
287287

288-
fn strip_aliases_and_projections(plan: &LogicalPlan) -> &LogicalPlan {
288+
fn strip_plan_wrappers(plan: &LogicalPlan) -> (&LogicalPlan, bool) {
289289
match plan {
290290
LogicalPlan::SubqueryAlias(subquery_alias) => {
291-
strip_aliases_and_projections(subquery_alias.input.as_ref())
291+
let (plan, _) = strip_plan_wrappers(subquery_alias.input.as_ref());
292+
(plan, true)
292293
}
293294
LogicalPlan::Projection(projection) => {
294-
strip_aliases_and_projections(projection.input.as_ref())
295+
let (plan, is_derived_relation) =
296+
strip_plan_wrappers(projection.input.as_ref());
297+
(plan, is_derived_relation)
295298
}
296-
_ => plan,
299+
_ => (plan, false),
297300
}
298301
}
299302

300303
fn is_scalar_aggregate_subquery(plan: &LogicalPlan) -> bool {
301304
matches!(
302-
strip_aliases_and_projections(plan),
305+
strip_plan_wrappers(plan).0,
303306
LogicalPlan::Aggregate(aggregate) if aggregate.group_expr.is_empty()
304307
)
305308
}
306309

307310
fn is_derived_relation(plan: &LogicalPlan) -> bool {
308-
match plan {
309-
LogicalPlan::SubqueryAlias(_) => true,
310-
LogicalPlan::Projection(projection) => {
311-
is_derived_relation(projection.input.as_ref())
312-
}
313-
_ => false,
314-
}
311+
strip_plan_wrappers(plan).1
312+
}
313+
314+
fn is_scalar_subquery_cross_join(join: &Join) -> bool {
315+
join.on.is_empty()
316+
&& join.filter.is_none()
317+
&& ((is_scalar_aggregate_subquery(join.left.as_ref())
318+
&& is_derived_relation(join.right.as_ref()))
319+
|| (is_scalar_aggregate_subquery(join.right.as_ref())
320+
&& is_derived_relation(join.left.as_ref())))
315321
}
316322

317323
// Keep post-join filters above certain scalar-subquery cross joins to preserve
@@ -320,19 +326,12 @@ fn should_keep_filter_above_scalar_subquery_cross_join(
320326
join: &Join,
321327
predicate: &Expr,
322328
) -> bool {
323-
if !join.on.is_empty() || join.filter.is_some() {
329+
if !is_scalar_subquery_cross_join(join) {
324330
return false;
325331
}
326332

327333
let mut checker = ColumnChecker::new(join.left.schema(), join.right.schema());
328-
let references_both_sides =
329-
!checker.is_left_only(predicate) && !checker.is_right_only(predicate);
330-
331-
references_both_sides
332-
&& ((is_scalar_aggregate_subquery(join.left.as_ref())
333-
&& is_derived_relation(join.right.as_ref()))
334-
|| (is_scalar_aggregate_subquery(join.right.as_ref())
335-
&& is_derived_relation(join.left.as_ref())))
334+
!checker.is_left_only(predicate) && !checker.is_right_only(predicate)
336335
}
337336

338337
/// examine OR clause to see if any useful clauses can be extracted and push down.
@@ -777,37 +776,31 @@ fn infer_join_predicates_impl<
777776
) -> Result<()> {
778777
for predicate in input_predicates {
779778
let column_refs = predicate.column_refs();
780-
let mut join_col_replacements = Vec::new();
781-
let mut has_non_replaceable_refs = false;
782-
783-
for &col in &column_refs {
784-
let mut replacement = None;
785-
786-
for (l, r) in join_col_keys.iter() {
787-
if ENABLE_LEFT_TO_RIGHT && col == *l {
788-
replacement = Some((col, *r));
789-
break;
790-
}
791-
if ENABLE_RIGHT_TO_LEFT && col == *r {
792-
replacement = Some((col, *l));
793-
break;
794-
}
795-
}
779+
let join_col_replacements: Vec<_> = column_refs
780+
.iter()
781+
.filter_map(|&col| {
782+
join_col_keys.iter().find_map(|(l, r)| {
783+
if ENABLE_LEFT_TO_RIGHT && col == *l {
784+
Some((col, *r))
785+
} else if ENABLE_RIGHT_TO_LEFT && col == *r {
786+
Some((col, *l))
787+
} else {
788+
None
789+
}
790+
})
791+
})
792+
.collect();
796793

797-
if let Some(replacement) = replacement {
798-
join_col_replacements.push(replacement);
799-
} else {
800-
has_non_replaceable_refs = true;
801-
}
802-
}
803794
if join_col_replacements.is_empty() {
804795
continue;
805796
}
806797

807798
// For non-inner joins, predicates that reference any non-replaceable
808799
// columns cannot be inferred on the other side. Skip the null-restriction
809800
// helper entirely in that common mixed-reference case.
810-
if !inferred_predicates.is_inner_join && has_non_replaceable_refs {
801+
if !inferred_predicates.is_inner_join
802+
&& join_col_replacements.len() != column_refs.len()
803+
{
811804
continue;
812805
}
813806

datafusion/optimizer/src/utils.rs

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,14 @@ pub fn is_restrict_null_predicate<'a>(
130130
// Collect join columns so they can be used in both the fast-path check and the
131131
// fallback evaluation path below.
132132
let join_cols: HashSet<&Column> = join_cols_of_predicate.into_iter().collect();
133+
let column_refs = predicate.column_refs();
133134

134135
// Fast path: if the predicate references columns outside the join key set,
135136
// `evaluate_expr_with_null_column` would fail because the null schema only
136137
// contains a placeholder for the join key columns. Callers treat such errors as
137138
// non-restricting (false) via `matches!(_, Ok(true))`, so we return false early
138139
// and avoid the expensive physical-expression compilation pipeline entirely.
139-
if !null_restriction::predicate_uses_only_columns(&predicate, &join_cols) {
140+
if !null_restriction::all_columns_allowed(&column_refs, &join_cols) {
140141
return Ok(false);
141142
}
142143

@@ -180,12 +181,10 @@ pub fn evaluates_to_null<'a>(
180181
return Ok(true);
181182
}
182183

183-
Ok(
184-
match evaluate_expr_with_null_column(predicate, null_columns)? {
185-
ColumnarValue::Array(_) => false,
186-
ColumnarValue::Scalar(scalar) => scalar.is_null(),
187-
},
188-
)
184+
Ok(authoritative_null_result(evaluate_expr_with_null_column(
185+
predicate,
186+
null_columns,
187+
)?)? == AuthoritativeNullResult::AlwaysNull)
189188
}
190189

191190
fn evaluate_expr_with_null_column<'a>(
@@ -219,22 +218,41 @@ fn authoritative_restrict_null_predicate<'a>(
219218
predicate: Expr,
220219
join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
221220
) -> Result<bool> {
222-
Ok(
223-
match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? {
224-
ColumnarValue::Array(array) => {
225-
if array.len() == 1 {
226-
let boolean_array = as_boolean_array(&array)?;
227-
boolean_array.is_null(0) || !boolean_array.value(0)
228-
} else {
229-
false
230-
}
221+
Ok(authoritative_null_result(evaluate_expr_with_null_column(
222+
predicate,
223+
join_cols_of_predicate,
224+
)?)? == AuthoritativeNullResult::NullRestricting)
225+
}
226+
227+
#[derive(Debug, PartialEq, Eq)]
228+
enum AuthoritativeNullResult {
229+
AlwaysNull,
230+
NullRestricting,
231+
Other,
232+
}
233+
234+
fn authoritative_null_result(value: ColumnarValue) -> Result<AuthoritativeNullResult> {
235+
Ok(match value {
236+
ColumnarValue::Array(array) => {
237+
if array.len() != 1 {
238+
return Ok(AuthoritativeNullResult::Other);
231239
}
232-
ColumnarValue::Scalar(scalar) => matches!(
233-
scalar,
234-
ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
235-
),
236-
},
237-
)
240+
241+
let boolean_array = as_boolean_array(&array)?;
242+
if boolean_array.is_null(0) || !boolean_array.value(0) {
243+
AuthoritativeNullResult::NullRestricting
244+
} else {
245+
AuthoritativeNullResult::Other
246+
}
247+
}
248+
ColumnarValue::Scalar(scalar) if scalar.is_null() => {
249+
AuthoritativeNullResult::AlwaysNull
250+
}
251+
ColumnarValue::Scalar(
252+
ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)),
253+
) => AuthoritativeNullResult::NullRestricting,
254+
ColumnarValue::Scalar(_) => AuthoritativeNullResult::Other,
255+
})
238256
}
239257

240258
fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {

0 commit comments

Comments
 (0)