Skip to content

Commit 7f5db8a

Browse files
Dandandanclaude
andcommitted
Simplify: extract shared dispatch_elementwise macro, clean up multi-batch API
- Extract duplicated DataType match block into dispatch_elementwise! macro used by both single-batch and multi-batch comparison functions - Change compare_rows_elementwise_multi to take &[&ArrayRef] instead of leaking (left_arrays_per_batch, key_idx) internal structure - Pre-compute left null buffers per batch outside the hot loop in do_compare_elementwise_multi Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 42aa7b2 commit 7f5db8a

1 file changed

Lines changed: 67 additions & 91 deletions

File tree

  • datafusion/physical-plan/src/joins

datafusion/physical-plan/src/joins/utils.rs

Lines changed: 67 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1899,12 +1899,15 @@ pub(super) fn equal_rows_arr(
18991899
null_equality,
19001900
)
19011901
} else {
1902+
let left_arrays_for_key: Vec<&ArrayRef> = left_arrays_per_batch
1903+
.iter()
1904+
.map(|batch_keys| &batch_keys[key_idx])
1905+
.collect();
19021906
compare_rows_elementwise_multi(
19031907
&mut equal_bits,
19041908
left_indices,
19051909
right_indices,
1906-
left_arrays_per_batch,
1907-
key_idx,
1910+
&left_arrays_for_key,
19081911
right_array,
19091912
null_equality,
19101913
)
@@ -1975,6 +1978,54 @@ fn and_bitmap_with_boolean_buffer(
19751978
}
19761979
}
19771980

1981+
/// Dispatch a macro call by Arrow DataType for element-wise comparison.
1982+
/// The `$action` macro is invoked with the concrete array type for each supported type.
1983+
/// Returns `false` for unsupported/nested types (caller should use fallback).
1984+
macro_rules! dispatch_elementwise {
1985+
($data_type:expr, $equal_bits:expr, $left_indices:expr, $null_equality:expr, $action:ident) => {
1986+
match $data_type {
1987+
DataType::Null => {
1988+
match $null_equality {
1989+
NullEquality::NullEqualsNothing => {
1990+
for i in 0..$left_indices.len() {
1991+
$equal_bits.set_bit(i, false);
1992+
}
1993+
}
1994+
NullEquality::NullEqualsNull => {}
1995+
}
1996+
}
1997+
DataType::Boolean => $action!(BooleanArray),
1998+
DataType::Int8 => $action!(Int8Array),
1999+
DataType::Int16 => $action!(Int16Array),
2000+
DataType::Int32 => $action!(Int32Array),
2001+
DataType::Int64 => $action!(Int64Array),
2002+
DataType::UInt8 => $action!(UInt8Array),
2003+
DataType::UInt16 => $action!(UInt16Array),
2004+
DataType::UInt32 => $action!(UInt32Array),
2005+
DataType::UInt64 => $action!(UInt64Array),
2006+
DataType::Float32 => $action!(Float32Array),
2007+
DataType::Float64 => $action!(Float64Array),
2008+
DataType::Binary => $action!(BinaryArray),
2009+
DataType::BinaryView => $action!(BinaryViewArray),
2010+
DataType::FixedSizeBinary(_) => $action!(FixedSizeBinaryArray),
2011+
DataType::LargeBinary => $action!(LargeBinaryArray),
2012+
DataType::Utf8 => $action!(StringArray),
2013+
DataType::Utf8View => $action!(StringViewArray),
2014+
DataType::LargeUtf8 => $action!(LargeStringArray),
2015+
DataType::Decimal128(..) => $action!(Decimal128Array),
2016+
DataType::Timestamp(time_unit, None) => match time_unit {
2017+
TimeUnit::Second => $action!(TimestampSecondArray),
2018+
TimeUnit::Millisecond => $action!(TimestampMillisecondArray),
2019+
TimeUnit::Microsecond => $action!(TimestampMicrosecondArray),
2020+
TimeUnit::Nanosecond => $action!(TimestampNanosecondArray),
2021+
},
2022+
DataType::Date32 => $action!(Date32Array),
2023+
DataType::Date64 => $action!(Date64Array),
2024+
_ => return false,
2025+
}
2026+
};
2027+
}
2028+
19782029
/// Compare rows element-wise without materializing intermediate arrays.
19792030
/// Returns `true` if the comparison was handled, `false` if fallback is needed.
19802031
///
@@ -1989,7 +2040,6 @@ fn compare_rows_elementwise(
19892040
right_array: &ArrayRef,
19902041
null_equality: NullEquality,
19912042
) -> bool {
1992-
// Nested types need special comparison logic, fall back
19932043
if left_array.data_type().is_nested() {
19942044
return false;
19952045
}
@@ -2009,47 +2059,7 @@ fn compare_rows_elementwise(
20092059
}};
20102060
}
20112061

2012-
match left_array.data_type() {
2013-
DataType::Null => {
2014-
match null_equality {
2015-
NullEquality::NullEqualsNothing => {
2016-
// null != null, clear all bits
2017-
for i in 0..left_indices.len() {
2018-
equal_bits.set_bit(i, false);
2019-
}
2020-
}
2021-
NullEquality::NullEqualsNull => {} // null == null, keep bits
2022-
}
2023-
}
2024-
DataType::Boolean => compare_elementwise!(BooleanArray),
2025-
DataType::Int8 => compare_elementwise!(Int8Array),
2026-
DataType::Int16 => compare_elementwise!(Int16Array),
2027-
DataType::Int32 => compare_elementwise!(Int32Array),
2028-
DataType::Int64 => compare_elementwise!(Int64Array),
2029-
DataType::UInt8 => compare_elementwise!(UInt8Array),
2030-
DataType::UInt16 => compare_elementwise!(UInt16Array),
2031-
DataType::UInt32 => compare_elementwise!(UInt32Array),
2032-
DataType::UInt64 => compare_elementwise!(UInt64Array),
2033-
DataType::Float32 => compare_elementwise!(Float32Array),
2034-
DataType::Float64 => compare_elementwise!(Float64Array),
2035-
DataType::Binary => compare_elementwise!(BinaryArray),
2036-
DataType::BinaryView => compare_elementwise!(BinaryViewArray),
2037-
DataType::FixedSizeBinary(_) => compare_elementwise!(FixedSizeBinaryArray),
2038-
DataType::LargeBinary => compare_elementwise!(LargeBinaryArray),
2039-
DataType::Utf8 => compare_elementwise!(StringArray),
2040-
DataType::Utf8View => compare_elementwise!(StringViewArray),
2041-
DataType::LargeUtf8 => compare_elementwise!(LargeStringArray),
2042-
DataType::Decimal128(..) => compare_elementwise!(Decimal128Array),
2043-
DataType::Timestamp(time_unit, None) => match time_unit {
2044-
TimeUnit::Second => compare_elementwise!(TimestampSecondArray),
2045-
TimeUnit::Millisecond => compare_elementwise!(TimestampMillisecondArray),
2046-
TimeUnit::Microsecond => compare_elementwise!(TimestampMicrosecondArray),
2047-
TimeUnit::Nanosecond => compare_elementwise!(TimestampNanosecondArray),
2048-
},
2049-
DataType::Date32 => compare_elementwise!(Date32Array),
2050-
DataType::Date64 => compare_elementwise!(Date64Array),
2051-
_ => return false, // Unsupported type, use fallback
2052-
}
2062+
dispatch_elementwise!(left_array.data_type(), equal_bits, left_indices, null_equality, compare_elementwise);
20532063
true
20542064
}
20552065

@@ -2116,8 +2126,7 @@ fn compare_rows_elementwise_multi(
21162126
equal_bits: &mut BooleanBufferBuilder,
21172127
left_indices: &[u64],
21182128
right_indices: &[u32],
2119-
left_arrays_per_batch: &[Vec<ArrayRef>],
2120-
key_idx: usize,
2129+
left_arrays: &[&ArrayRef],
21212130
right_array: &ArrayRef,
21222131
null_equality: NullEquality,
21232132
) -> bool {
@@ -2127,9 +2136,9 @@ fn compare_rows_elementwise_multi(
21272136

21282137
macro_rules! compare_multi {
21292138
($array_type:ty) => {{
2130-
let left_typed: Vec<&$array_type> = left_arrays_per_batch
2139+
let left_typed: Vec<&$array_type> = left_arrays
21312140
.iter()
2132-
.map(|keys| keys[key_idx].as_any().downcast_ref::<$array_type>().unwrap())
2141+
.map(|a| a.as_any().downcast_ref::<$array_type>().unwrap())
21332142
.collect();
21342143
let right = right_array.as_any().downcast_ref::<$array_type>().unwrap();
21352144
do_compare_elementwise_multi(
@@ -2143,46 +2152,7 @@ fn compare_rows_elementwise_multi(
21432152
}};
21442153
}
21452154

2146-
match right_array.data_type() {
2147-
DataType::Null => {
2148-
match null_equality {
2149-
NullEquality::NullEqualsNothing => {
2150-
for i in 0..left_indices.len() {
2151-
equal_bits.set_bit(i, false);
2152-
}
2153-
}
2154-
NullEquality::NullEqualsNull => {}
2155-
}
2156-
}
2157-
DataType::Boolean => compare_multi!(BooleanArray),
2158-
DataType::Int8 => compare_multi!(Int8Array),
2159-
DataType::Int16 => compare_multi!(Int16Array),
2160-
DataType::Int32 => compare_multi!(Int32Array),
2161-
DataType::Int64 => compare_multi!(Int64Array),
2162-
DataType::UInt8 => compare_multi!(UInt8Array),
2163-
DataType::UInt16 => compare_multi!(UInt16Array),
2164-
DataType::UInt32 => compare_multi!(UInt32Array),
2165-
DataType::UInt64 => compare_multi!(UInt64Array),
2166-
DataType::Float32 => compare_multi!(Float32Array),
2167-
DataType::Float64 => compare_multi!(Float64Array),
2168-
DataType::Binary => compare_multi!(BinaryArray),
2169-
DataType::BinaryView => compare_multi!(BinaryViewArray),
2170-
DataType::FixedSizeBinary(_) => compare_multi!(FixedSizeBinaryArray),
2171-
DataType::LargeBinary => compare_multi!(LargeBinaryArray),
2172-
DataType::Utf8 => compare_multi!(StringArray),
2173-
DataType::Utf8View => compare_multi!(StringViewArray),
2174-
DataType::LargeUtf8 => compare_multi!(LargeStringArray),
2175-
DataType::Decimal128(..) => compare_multi!(Decimal128Array),
2176-
DataType::Timestamp(time_unit, None) => match time_unit {
2177-
TimeUnit::Second => compare_multi!(TimestampSecondArray),
2178-
TimeUnit::Millisecond => compare_multi!(TimestampMillisecondArray),
2179-
TimeUnit::Microsecond => compare_multi!(TimestampMicrosecondArray),
2180-
TimeUnit::Nanosecond => compare_multi!(TimestampNanosecondArray),
2181-
},
2182-
DataType::Date32 => compare_multi!(Date32Array),
2183-
DataType::Date64 => compare_multi!(Date64Array),
2184-
_ => return false,
2185-
}
2155+
dispatch_elementwise!(right_array.data_type(), equal_bits, left_indices, null_equality, compare_multi);
21862156
true
21872157
}
21882158

@@ -2222,6 +2192,10 @@ fn do_compare_elementwise_multi<A: ArrayAccessor>(
22222192
}
22232193
}
22242194
} else {
2195+
// Pre-compute null buffers per batch to avoid repeated method calls in the loop
2196+
let left_nulls_per_batch: Vec<Option<&NullBuffer>> =
2197+
left_arrays.iter().map(|a| a.nulls()).collect();
2198+
22252199
for i in 0..num_rows {
22262200
if !equal_bits.get_bit(i) {
22272201
continue;
@@ -2230,14 +2204,16 @@ fn do_compare_elementwise_multi<A: ArrayAccessor>(
22302204
let batch_idx = (packed >> 32) as usize;
22312205
let row_idx = (packed & 0xFFFFFFFF) as usize;
22322206
let r_idx = right_indices[i] as usize;
2233-
let left = &left_arrays[batch_idx];
2234-
let l_null = left.nulls().is_some_and(|n| !n.is_valid(row_idx));
2207+
let l_null = left_nulls_per_batch[batch_idx]
2208+
.is_some_and(|n| !n.is_valid(row_idx));
22352209
let r_null = right_nulls.is_some_and(|n| !n.is_valid(r_idx));
22362210

22372211
let is_equal = match (l_null, r_null) {
22382212
(true, true) => null_equality == NullEquality::NullEqualsNull,
22392213
(true, false) | (false, true) => false,
2240-
(false, false) => left.value(row_idx) == right.value(r_idx),
2214+
(false, false) => {
2215+
left_arrays[batch_idx].value(row_idx) == right.value(r_idx)
2216+
}
22412217
};
22422218
if !is_equal {
22432219
equal_bits.set_bit(i, false);

0 commit comments

Comments
 (0)