Skip to content

Commit 232a7bf

Browse files
committed
step 2
1 parent b943816 commit 232a7bf

8 files changed

Lines changed: 521 additions & 112 deletions

File tree

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3909,13 +3909,13 @@ pub struct Join {
39093909
pub schema: DFSchemaRef,
39103910
/// Defines the null equality for the join.
39113911
pub null_equality: NullEquality,
3912-
/// Whether this is a null-aware anti join (for NOT IN semantics).
3912+
/// Whether this join needs null-aware NOT IN semantics.
39133913
///
3914-
/// Only applies to LeftAnti joins. When true, implements SQL NOT IN semantics where:
3915-
/// - If the right side (subquery) contains any NULL in join keys, no rows are output
3916-
/// - Left side rows with NULL in join keys are not output
3914+
/// For `LeftAnti`, if the right side contains any NULL in join keys, no rows are output and
3915+
/// left rows with NULL join keys are also excluded.
39173916
///
3918-
/// This is required for correct NOT IN subquery behavior with three-valued logic.
3917+
/// For `LeftMark`, the generated `mark` column becomes nullable so unmatched rows can produce
3918+
/// `NULL` rather than `false` when SQL three-valued logic requires it.
39193919
pub null_aware: bool,
39203920
}
39213921

@@ -3934,7 +3934,7 @@ impl Join {
39343934
/// * `join_type` - Type of join (Inner, Left, Right, etc.)
39353935
/// * `join_constraint` - Join constraint (On, Using)
39363936
/// * `null_equality` - How to handle nulls in join comparisons
3937-
/// * `null_aware` - Whether this is a null-aware anti join (for NOT IN semantics)
3937+
/// * `null_aware` - Whether this join needs null-aware NOT IN semantics
39383938
///
39393939
/// # Returns
39403940
///

datafusion/optimizer/src/decorrelate_predicate_subquery.rs

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,8 @@ fn build_join(
371371
.values()
372372
.for_each(|cols| all_correlated_cols.extend(cols.clone()));
373373

374+
let has_correlated_join_filter = !pull_up.join_filters.is_empty();
375+
374376
// alias the join filter
375377
let join_filter_opt = conjunction(pull_up.join_filters)
376378
.map_or(Ok(None), |filter| {
@@ -440,9 +442,27 @@ fn build_join(
440442
sub_query_alias.clone()
441443
};
442444

443-
// Mark joins don't use null-aware semantics (they use three-valued logic with mark column)
445+
// For simple uncorrelated NOT IN disjunctions, propagate null-aware semantics into the
446+
// nullable mark column. Correlated mark joins still use the legacy path because the
447+
// runtime state is global to the probe side rather than per-left-row.
448+
let null_aware = join_type == JoinType::LeftMark
449+
&& in_predicate_opt.is_some()
450+
&& !has_correlated_join_filter
451+
&& join_keys_may_be_null(
452+
&join_filter,
453+
left.schema(),
454+
right_projected.schema(),
455+
)?;
456+
444457
let new_plan = LogicalPlanBuilder::from(left.clone())
445-
.join_on(right_projected, join_type, Some(join_filter))?
458+
.join_detailed_with_options(
459+
right_projected,
460+
join_type,
461+
(Vec::<Column>::new(), Vec::<Column>::new()),
462+
Some(join_filter),
463+
NullEquality::NullEqualsNothing,
464+
null_aware,
465+
)?
446466
.build()?;
447467

448468
debug!(
@@ -461,13 +481,9 @@ fn build_join(
461481
//
462482
// Additionally, if the join keys are non-nullable on both sides, we don't need
463483
// null-aware semantics because NULLs cannot exist in the data.
464-
let null_aware = dbg!(matches!(join_type, JoinType::LeftAnti | JoinType::LeftMark))
465-
&& dbg!(in_predicate_opt.is_some())
466-
&& dbg!(join_keys_may_be_null(
467-
&join_filter,
468-
left.schema(),
469-
sub_query_alias.schema()
470-
)?);
484+
let null_aware = matches!(join_type, JoinType::LeftAnti)
485+
&& in_predicate_opt.is_some()
486+
&& join_keys_may_be_null(&join_filter, left.schema(), sub_query_alias.schema())?;
471487

472488
// join our sub query into the main plan
473489
let new_plan = if null_aware {
@@ -1740,8 +1756,8 @@ mod tests {
17401756
plan,
17411757
@r"
17421758
Projection: customer.c_custkey [c_custkey:Int64]
1743-
Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1744-
LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1759+
Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean;N]
1760+
LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean;N]
17451761
TableScan: customer [c_custkey:Int64, c_name:Utf8]
17461762
SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
17471763
Projection: orders.o_custkey [o_custkey:Int64]

datafusion/physical-plan/src/joins/hash_join/exec.rs

Lines changed: 113 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,8 @@ pub(super) struct JoinLeftData {
215215
/// Shared atomic flag indicating if any probe partition saw data (for null-aware anti/mark joins)
216216
/// This is shared across all probe partitions to provide global knowledge
217217
pub(super) probe_side_non_empty: AtomicBool,
218-
/// Shared atomic flag indicating if any probe partition saw NULL in join keys (for null-aware anti joins)
218+
/// Shared atomic flag indicating if any probe partition saw NULL in join keys
219219
pub(super) probe_side_has_null: AtomicBool,
220-
/// Shared atomic flag indicating if any build partition saw NULL in join keys (for null-aware mark joins)
221-
pub(super) build_side_has_nulls: AtomicBool,
222-
/// Not sure how to use this yet
223-
pub(super) build_side_is_empty: AtomicBool,
224220
}
225221

226222
impl JoinLeftData {
@@ -409,15 +405,15 @@ impl HashJoinExecBuilder {
409405
// Validate null_aware flag
410406
if exec.null_aware {
411407
let join_type = exec.join_type();
412-
if !matches!(join_type, JoinType::LeftAnti) {
408+
if !matches!(join_type, JoinType::LeftAnti | JoinType::LeftMark) {
413409
return plan_err!(
414-
"null_aware can only be true for LeftAnti joins, got {join_type}"
410+
"null_aware can only be true for LeftAnti or LeftMark joins, got {join_type}"
415411
);
416412
}
417413
let on = exec.on();
418414
if on.len() != 1 {
419415
return plan_err!(
420-
"null_aware anti join only supports single column join key, got {} columns",
416+
"null_aware joins only support single column join key, got {} columns",
421417
on.len()
422418
);
423419
}
@@ -2079,9 +2075,6 @@ async fn collect_left_input(
20792075
bounds = None;
20802076
}
20812077

2082-
let build_side_has_nulls = batch.columns().iter().any(|col| col.null_count() > 0);
2083-
let build_side_is_empty = batch.num_rows() == 0;
2084-
20852078
let data = JoinLeftData {
20862079
map,
20872080
batch,
@@ -2093,8 +2086,6 @@ async fn collect_left_input(
20932086
membership,
20942087
probe_side_non_empty: AtomicBool::new(false),
20952088
probe_side_has_null: AtomicBool::new(false),
2096-
build_side_has_nulls: AtomicBool::new(build_side_has_nulls),
2097-
build_side_is_empty: AtomicBool::new(build_side_is_empty),
20982089
};
20992090

21002091
Ok(data)
@@ -6067,7 +6058,7 @@ mod tests {
60676058
Ok(())
60686059
}
60696060

6070-
/// Test that null_aware validation rejects non-LeftAnti join types
6061+
/// Test that null_aware validation rejects unsupported join types
60716062
#[tokio::test]
60726063
async fn test_null_aware_validation_wrong_join_type() {
60736064
let left =
@@ -6098,7 +6089,7 @@ mod tests {
60986089
result
60996090
.unwrap_err()
61006091
.to_string()
6101-
.contains("null_aware can only be true for LeftAnti joins")
6092+
.contains("null_aware can only be true for LeftAnti or LeftMark joins")
61026093
);
61036094
}
61046095

@@ -6138,8 +6129,114 @@ mod tests {
61386129
result
61396130
.unwrap_err()
61406131
.to_string()
6141-
.contains("null_aware anti join only supports single column join key")
6132+
.contains("null_aware joins only support single column join key")
6133+
);
6134+
}
6135+
6136+
/// Test null-aware left mark join when probe side contains NULL.
6137+
/// Expected:
6138+
/// - matched rows => true
6139+
/// - unmatched non-NULL rows => NULL
6140+
/// - NULL build keys with non-empty probe side => NULL
6141+
#[apply(hash_join_exec_configs)]
6142+
#[tokio::test]
6143+
async fn test_null_aware_left_mark_probe_null(batch_size: usize) -> Result<()> {
6144+
let task_ctx = prepare_task_ctx(batch_size, false);
6145+
6146+
let left = build_table_two_cols(
6147+
("c1", &vec![Some(1), Some(4), None]),
6148+
("dummy", &vec![Some(10), Some(40), Some(0)]),
6149+
);
6150+
6151+
let right = build_table_two_cols(
6152+
("c2", &vec![Some(1), Some(2), None]),
6153+
("dummy", &vec![Some(100), Some(200), Some(300)]),
6154+
);
6155+
6156+
let on = vec![(
6157+
Arc::new(Column::new_with_schema("c1", &left.schema())?) as _,
6158+
Arc::new(Column::new_with_schema("c2", &right.schema())?) as _,
6159+
)];
6160+
6161+
let join = HashJoinExec::try_new(
6162+
left,
6163+
right,
6164+
on,
6165+
None,
6166+
&JoinType::LeftMark,
6167+
None,
6168+
PartitionMode::CollectLeft,
6169+
NullEquality::NullEqualsNothing,
6170+
true, // null_aware = true
6171+
)?;
6172+
6173+
let stream = join.execute(0, task_ctx)?;
6174+
let batches = common::collect(stream).await?;
6175+
6176+
allow_duplicates! {
6177+
assert_snapshot!(batches_to_sort_string(&batches), @r"
6178+
+----+-------+------+
6179+
| c1 | dummy | mark |
6180+
+----+-------+------+
6181+
| | 0 | |
6182+
| 1 | 10 | true |
6183+
| 4 | 40 | |
6184+
+----+-------+------+
6185+
");
6186+
}
6187+
6188+
Ok(())
6189+
}
6190+
6191+
/// Test null-aware left mark join when probe side is empty.
6192+
/// Expected: all rows are marked false, including NULL build keys.
6193+
#[apply(hash_join_exec_configs)]
6194+
#[tokio::test]
6195+
async fn test_null_aware_left_mark_empty_probe(batch_size: usize) -> Result<()> {
6196+
let task_ctx = prepare_task_ctx(batch_size, false);
6197+
6198+
let left = build_table_two_cols(
6199+
("c1", &vec![Some(1), None]),
6200+
("dummy", &vec![Some(10), Some(0)]),
61426201
);
6202+
6203+
let right = build_table_two_cols(
6204+
("c2", &Vec::<Option<i32>>::new()),
6205+
("dummy", &Vec::<Option<i32>>::new()),
6206+
);
6207+
6208+
let on = vec![(
6209+
Arc::new(Column::new_with_schema("c1", &left.schema())?) as _,
6210+
Arc::new(Column::new_with_schema("c2", &right.schema())?) as _,
6211+
)];
6212+
6213+
let join = HashJoinExec::try_new(
6214+
left,
6215+
right,
6216+
on,
6217+
None,
6218+
&JoinType::LeftMark,
6219+
None,
6220+
PartitionMode::CollectLeft,
6221+
NullEquality::NullEqualsNothing,
6222+
true, // null_aware = true
6223+
)?;
6224+
6225+
let stream = join.execute(0, task_ctx)?;
6226+
let batches = common::collect(stream).await?;
6227+
6228+
allow_duplicates! {
6229+
assert_snapshot!(batches_to_sort_string(&batches), @r"
6230+
+----+-------+-------+
6231+
| c1 | dummy | mark |
6232+
+----+-------+-------+
6233+
| | 0 | false |
6234+
| 1 | 10 | false |
6235+
+----+-------+-------+
6236+
");
6237+
}
6238+
6239+
Ok(())
61436240
}
61446241

61456242
#[test]

0 commit comments

Comments
 (0)