@@ -1899,12 +1899,15 @@ pub(super) fn equal_rows_arr(
18991899 null_equality,
19001900 )
19011901 } else {
1902+ let left_arrays_for_key: Vec < & ArrayRef > = left_arrays_per_batch
1903+ . iter ( )
1904+ . map ( |batch_keys| & batch_keys[ key_idx] )
1905+ . collect ( ) ;
19021906 compare_rows_elementwise_multi (
19031907 & mut equal_bits,
19041908 left_indices,
19051909 right_indices,
1906- left_arrays_per_batch,
1907- key_idx,
1910+ & left_arrays_for_key,
19081911 right_array,
19091912 null_equality,
19101913 )
@@ -1975,6 +1978,54 @@ fn and_bitmap_with_boolean_buffer(
19751978 }
19761979}
19771980
1981+ /// Dispatches on an Arrow `DataType` and invokes the `$action` macro with the matching concrete array type (e.g. `Int32Array` for `DataType::Int32`).
1982+ /// `DataType::Null` is handled inline instead of via `$action`: under `NullEqualsNothing` bits `0..$left_indices.len()` of `$equal_bits` are cleared, under `NullEqualsNull` they are left untouched.
1983+ /// For every other (unsupported/nested) type the expansion executes `return false`, which exits the *enclosing function* rather than yielding a value, so the macro must be expanded inside a function returning `bool`; the caller should then use its fallback path.
1984+ macro_rules! dispatch_elementwise {
1985+ ( $data_type: expr, $equal_bits: expr, $left_indices: expr, $null_equality: expr, $action: ident) => {
1986+ match $data_type {
1987+ DataType :: Null => {
1988+ match $null_equality {
1989+ NullEquality :: NullEqualsNothing => {
1990+ for i in 0 ..$left_indices. len( ) { // null never equals null: clear one bit per candidate row
1991+ $equal_bits. set_bit( i, false ) ;
1992+ }
1993+ }
1994+ NullEquality :: NullEqualsNull => { } // null equals null: keep whatever bits are already set
1995+ }
1996+ }
1997+ DataType :: Boolean => $action!( BooleanArray ) ,
1998+ DataType :: Int8 => $action!( Int8Array ) ,
1999+ DataType :: Int16 => $action!( Int16Array ) ,
2000+ DataType :: Int32 => $action!( Int32Array ) ,
2001+ DataType :: Int64 => $action!( Int64Array ) ,
2002+ DataType :: UInt8 => $action!( UInt8Array ) ,
2003+ DataType :: UInt16 => $action!( UInt16Array ) ,
2004+ DataType :: UInt32 => $action!( UInt32Array ) ,
2005+ DataType :: UInt64 => $action!( UInt64Array ) ,
2006+ DataType :: Float32 => $action!( Float32Array ) ,
2007+ DataType :: Float64 => $action!( Float64Array ) ,
2008+ DataType :: Binary => $action!( BinaryArray ) ,
2009+ DataType :: BinaryView => $action!( BinaryViewArray ) ,
2010+ DataType :: FixedSizeBinary ( _) => $action!( FixedSizeBinaryArray ) ,
2011+ DataType :: LargeBinary => $action!( LargeBinaryArray ) ,
2012+ DataType :: Utf8 => $action!( StringArray ) ,
2013+ DataType :: Utf8View => $action!( StringViewArray ) ,
2014+ DataType :: LargeUtf8 => $action!( LargeStringArray ) ,
2015+ DataType :: Decimal128 ( ..) => $action!( Decimal128Array ) ,
2016+ DataType :: Timestamp ( time_unit, None ) => match time_unit { // matches only when the second field (the timezone in Arrow's `DataType::Timestamp`) is `None`; other timestamps reach the catch-all arm
2017+ TimeUnit :: Second => $action!( TimestampSecondArray ) ,
2018+ TimeUnit :: Millisecond => $action!( TimestampMillisecondArray ) ,
2019+ TimeUnit :: Microsecond => $action!( TimestampMicrosecondArray ) ,
2020+ TimeUnit :: Nanosecond => $action!( TimestampNanosecondArray ) ,
2021+ } ,
2022+ DataType :: Date32 => $action!( Date32Array ) ,
2023+ DataType :: Date64 => $action!( Date64Array ) ,
2024+ _ => return false , // early exit from the function that expanded this macro, signalling "not handled"
2025+ }
2026+ } ;
2027+ }
2028+
19782029/// Compare rows element-wise without materializing intermediate arrays.
19792030/// Returns `true` if the comparison was handled, `false` if fallback is needed.
19802031///
@@ -1989,7 +2040,6 @@ fn compare_rows_elementwise(
19892040 right_array : & ArrayRef ,
19902041 null_equality : NullEquality ,
19912042) -> bool {
1992- // Nested types need special comparison logic, fall back
19932043 if left_array. data_type ( ) . is_nested ( ) {
19942044 return false ;
19952045 }
@@ -2009,47 +2059,7 @@ fn compare_rows_elementwise(
20092059 } } ;
20102060 }
20112061
2012- match left_array. data_type ( ) {
2013- DataType :: Null => {
2014- match null_equality {
2015- NullEquality :: NullEqualsNothing => {
2016- // null != null, clear all bits
2017- for i in 0 ..left_indices. len ( ) {
2018- equal_bits. set_bit ( i, false ) ;
2019- }
2020- }
2021- NullEquality :: NullEqualsNull => { } // null == null, keep bits
2022- }
2023- }
2024- DataType :: Boolean => compare_elementwise ! ( BooleanArray ) ,
2025- DataType :: Int8 => compare_elementwise ! ( Int8Array ) ,
2026- DataType :: Int16 => compare_elementwise ! ( Int16Array ) ,
2027- DataType :: Int32 => compare_elementwise ! ( Int32Array ) ,
2028- DataType :: Int64 => compare_elementwise ! ( Int64Array ) ,
2029- DataType :: UInt8 => compare_elementwise ! ( UInt8Array ) ,
2030- DataType :: UInt16 => compare_elementwise ! ( UInt16Array ) ,
2031- DataType :: UInt32 => compare_elementwise ! ( UInt32Array ) ,
2032- DataType :: UInt64 => compare_elementwise ! ( UInt64Array ) ,
2033- DataType :: Float32 => compare_elementwise ! ( Float32Array ) ,
2034- DataType :: Float64 => compare_elementwise ! ( Float64Array ) ,
2035- DataType :: Binary => compare_elementwise ! ( BinaryArray ) ,
2036- DataType :: BinaryView => compare_elementwise ! ( BinaryViewArray ) ,
2037- DataType :: FixedSizeBinary ( _) => compare_elementwise ! ( FixedSizeBinaryArray ) ,
2038- DataType :: LargeBinary => compare_elementwise ! ( LargeBinaryArray ) ,
2039- DataType :: Utf8 => compare_elementwise ! ( StringArray ) ,
2040- DataType :: Utf8View => compare_elementwise ! ( StringViewArray ) ,
2041- DataType :: LargeUtf8 => compare_elementwise ! ( LargeStringArray ) ,
2042- DataType :: Decimal128 ( ..) => compare_elementwise ! ( Decimal128Array ) ,
2043- DataType :: Timestamp ( time_unit, None ) => match time_unit {
2044- TimeUnit :: Second => compare_elementwise ! ( TimestampSecondArray ) ,
2045- TimeUnit :: Millisecond => compare_elementwise ! ( TimestampMillisecondArray ) ,
2046- TimeUnit :: Microsecond => compare_elementwise ! ( TimestampMicrosecondArray ) ,
2047- TimeUnit :: Nanosecond => compare_elementwise ! ( TimestampNanosecondArray ) ,
2048- } ,
2049- DataType :: Date32 => compare_elementwise ! ( Date32Array ) ,
2050- DataType :: Date64 => compare_elementwise ! ( Date64Array ) ,
2051- _ => return false , // Unsupported type, use fallback
2052- }
2062+ dispatch_elementwise ! ( left_array. data_type( ) , equal_bits, left_indices, null_equality, compare_elementwise) ;
20532063 true
20542064}
20552065
@@ -2116,8 +2126,7 @@ fn compare_rows_elementwise_multi(
21162126 equal_bits : & mut BooleanBufferBuilder ,
21172127 left_indices : & [ u64 ] ,
21182128 right_indices : & [ u32 ] ,
2119- left_arrays_per_batch : & [ Vec < ArrayRef > ] ,
2120- key_idx : usize ,
2129+ left_arrays : & [ & ArrayRef ] ,
21212130 right_array : & ArrayRef ,
21222131 null_equality : NullEquality ,
21232132) -> bool {
@@ -2127,9 +2136,9 @@ fn compare_rows_elementwise_multi(
21272136
21282137 macro_rules! compare_multi {
21292138 ( $array_type: ty) => { {
2130- let left_typed: Vec <& $array_type> = left_arrays_per_batch
2139+ let left_typed: Vec <& $array_type> = left_arrays
21312140 . iter( )
2132- . map( |keys| keys [ key_idx ] . as_any( ) . downcast_ref:: <$array_type>( ) . unwrap( ) )
2141+ . map( |a| a . as_any( ) . downcast_ref:: <$array_type>( ) . unwrap( ) )
21332142 . collect( ) ;
21342143 let right = right_array. as_any( ) . downcast_ref:: <$array_type>( ) . unwrap( ) ;
21352144 do_compare_elementwise_multi(
@@ -2143,46 +2152,7 @@ fn compare_rows_elementwise_multi(
21432152 } } ;
21442153 }
21452154
2146- match right_array. data_type ( ) {
2147- DataType :: Null => {
2148- match null_equality {
2149- NullEquality :: NullEqualsNothing => {
2150- for i in 0 ..left_indices. len ( ) {
2151- equal_bits. set_bit ( i, false ) ;
2152- }
2153- }
2154- NullEquality :: NullEqualsNull => { }
2155- }
2156- }
2157- DataType :: Boolean => compare_multi ! ( BooleanArray ) ,
2158- DataType :: Int8 => compare_multi ! ( Int8Array ) ,
2159- DataType :: Int16 => compare_multi ! ( Int16Array ) ,
2160- DataType :: Int32 => compare_multi ! ( Int32Array ) ,
2161- DataType :: Int64 => compare_multi ! ( Int64Array ) ,
2162- DataType :: UInt8 => compare_multi ! ( UInt8Array ) ,
2163- DataType :: UInt16 => compare_multi ! ( UInt16Array ) ,
2164- DataType :: UInt32 => compare_multi ! ( UInt32Array ) ,
2165- DataType :: UInt64 => compare_multi ! ( UInt64Array ) ,
2166- DataType :: Float32 => compare_multi ! ( Float32Array ) ,
2167- DataType :: Float64 => compare_multi ! ( Float64Array ) ,
2168- DataType :: Binary => compare_multi ! ( BinaryArray ) ,
2169- DataType :: BinaryView => compare_multi ! ( BinaryViewArray ) ,
2170- DataType :: FixedSizeBinary ( _) => compare_multi ! ( FixedSizeBinaryArray ) ,
2171- DataType :: LargeBinary => compare_multi ! ( LargeBinaryArray ) ,
2172- DataType :: Utf8 => compare_multi ! ( StringArray ) ,
2173- DataType :: Utf8View => compare_multi ! ( StringViewArray ) ,
2174- DataType :: LargeUtf8 => compare_multi ! ( LargeStringArray ) ,
2175- DataType :: Decimal128 ( ..) => compare_multi ! ( Decimal128Array ) ,
2176- DataType :: Timestamp ( time_unit, None ) => match time_unit {
2177- TimeUnit :: Second => compare_multi ! ( TimestampSecondArray ) ,
2178- TimeUnit :: Millisecond => compare_multi ! ( TimestampMillisecondArray ) ,
2179- TimeUnit :: Microsecond => compare_multi ! ( TimestampMicrosecondArray ) ,
2180- TimeUnit :: Nanosecond => compare_multi ! ( TimestampNanosecondArray ) ,
2181- } ,
2182- DataType :: Date32 => compare_multi ! ( Date32Array ) ,
2183- DataType :: Date64 => compare_multi ! ( Date64Array ) ,
2184- _ => return false ,
2185- }
2155+ dispatch_elementwise ! ( right_array. data_type( ) , equal_bits, left_indices, null_equality, compare_multi) ;
21862156 true
21872157}
21882158
@@ -2222,6 +2192,10 @@ fn do_compare_elementwise_multi<A: ArrayAccessor>(
22222192 }
22232193 }
22242194 } else {
2195+ // Pre-compute null buffers per batch to avoid repeated method calls in the loop
2196+ let left_nulls_per_batch: Vec < Option < & NullBuffer > > =
2197+ left_arrays. iter ( ) . map ( |a| a. nulls ( ) ) . collect ( ) ;
2198+
22252199 for i in 0 ..num_rows {
22262200 if !equal_bits. get_bit ( i) {
22272201 continue ;
@@ -2230,14 +2204,16 @@ fn do_compare_elementwise_multi<A: ArrayAccessor>(
22302204 let batch_idx = ( packed >> 32 ) as usize ;
22312205 let row_idx = ( packed & 0xFFFFFFFF ) as usize ;
22322206 let r_idx = right_indices[ i] as usize ;
2233- let left = & left_arrays [ batch_idx] ;
2234- let l_null = left . nulls ( ) . is_some_and ( |n| !n. is_valid ( row_idx) ) ;
2207+ let l_null = left_nulls_per_batch [ batch_idx]
2208+ . is_some_and ( |n| !n. is_valid ( row_idx) ) ;
22352209 let r_null = right_nulls. is_some_and ( |n| !n. is_valid ( r_idx) ) ;
22362210
22372211 let is_equal = match ( l_null, r_null) {
22382212 ( true , true ) => null_equality == NullEquality :: NullEqualsNull ,
22392213 ( true , false ) | ( false , true ) => false ,
2240- ( false , false ) => left. value ( row_idx) == right. value ( r_idx) ,
2214+ ( false , false ) => {
2215+ left_arrays[ batch_idx] . value ( row_idx) == right. value ( r_idx)
2216+ }
22412217 } ;
22422218 if !is_equal {
22432219 equal_bits. set_bit ( i, false ) ;
0 commit comments