Skip to content

Commit 88ad0f8

Browse files
perf(in_list): add PrimitiveFilter for integer types
Adds PrimitiveFilter<T> - a generic HashSet-based filter for primitive types. Also adds contains_slice() method for zero-copy buffer access used by type reinterpretation. The strategy module now uses PrimitiveFilter directly for Int32/Int64/UInt32/UInt64 instead of the macro-generated filters, while keeping Float32/Float64 with OrderedFloat wrappers for now.
1 parent a997d7f commit 88ad0f8

3 files changed

Lines changed: 132 additions & 142 deletions

File tree

datafusion/physical-expr/src/expressions/in_list/primitive.rs

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
//!
2020
//! This module provides high-performance membership testing for Arrow primitive types.
2121
22+
use std::hash::Hash;
23+
2224
use arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
2325
use arrow::datatypes::ArrowPrimitiveType;
24-
use datafusion_common::{Result, exec_datafusion_err};
26+
use datafusion_common::{HashSet, Result, exec_datafusion_err};
2527

2628
use super::array_filter::StaticFilter;
2729
use super::result::{build_in_list_result, handle_dictionary};
@@ -171,3 +173,78 @@ impl<C: BitmapFilterConfig> StaticFilter for BitmapFilter<C> {
171173
))
172174
}
173175
}
176+
177+
// =============================================================================
178+
// PRIMITIVE FILTER (Hash-based)
179+
// =============================================================================
180+
181+
/// Hash-based filter for primitive types with larger IN lists.
182+
pub(crate) struct PrimitiveFilter<T: ArrowPrimitiveType> {
183+
null_count: usize,
184+
set: HashSet<T::Native>,
185+
}
186+
187+
impl<T: ArrowPrimitiveType> PrimitiveFilter<T>
188+
where
189+
T::Native: Hash + Eq,
190+
{
191+
pub(crate) fn try_new(in_array: &ArrayRef) -> Result<Self> {
192+
let arr = in_array.as_primitive_opt::<T>().ok_or_else(|| {
193+
exec_datafusion_err!(
194+
"PrimitiveFilter: expected {} array",
195+
std::any::type_name::<T>()
196+
)
197+
})?;
198+
Ok(Self {
199+
null_count: arr.null_count(),
200+
set: arr.iter().flatten().collect(),
201+
})
202+
}
203+
204+
/// Check membership using a raw values slice (zero-copy path for type reinterpretation).
205+
#[inline]
206+
pub(crate) fn contains_slice(
207+
&self,
208+
values: &[T::Native],
209+
nulls: Option<&arrow::buffer::NullBuffer>,
210+
negated: bool,
211+
) -> BooleanArray {
212+
build_in_list_result(
213+
values.len(),
214+
nulls,
215+
self.null_count > 0,
216+
negated,
217+
// SAFETY: i is in bounds since we iterate 0..values.len()
218+
|i| self.set.contains(unsafe { values.get_unchecked(i) }),
219+
)
220+
}
221+
}
222+
223+
impl<T> StaticFilter for PrimitiveFilter<T>
224+
where
225+
T: ArrowPrimitiveType + 'static,
226+
T::Native: Hash + Eq + Send + Sync + 'static,
227+
{
228+
fn null_count(&self) -> usize {
229+
self.null_count
230+
}
231+
232+
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
233+
handle_dictionary!(self, v, negated);
234+
let v = v.as_primitive_opt::<T>().ok_or_else(|| {
235+
exec_datafusion_err!(
236+
"PrimitiveFilter: expected {} array",
237+
std::any::type_name::<T>()
238+
)
239+
})?;
240+
let values = v.values();
241+
Ok(build_in_list_result(
242+
v.len(),
243+
v.nulls(),
244+
self.null_count > 0,
245+
negated,
246+
// SAFETY: i is in bounds since we iterate 0..v.len()
247+
|i| self.set.contains(unsafe { values.get_unchecked(i) }),
248+
))
249+
}
250+
}

datafusion/physical-expr/src/expressions/in_list/strategy.rs

Lines changed: 5 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use arrow::datatypes::*;
2727
use datafusion_common::{HashSet, Result, exec_datafusion_err};
2828

2929
use super::array_filter::{ArrayStaticFilter, StaticFilter};
30-
use super::primitive::{U8Config, U16Config};
30+
use super::primitive::{PrimitiveFilter, U8Config, U16Config};
3131
use super::transform::make_bitmap_filter;
3232

3333
pub(crate) fn instantiate_static_filter(
@@ -39,11 +39,11 @@ pub(crate) fn instantiate_static_filter(
3939
// 2-byte types: use bitmap (65536 bits = 8 KB)
4040
DataType::Int16 | DataType::UInt16 => make_bitmap_filter::<U16Config>(&in_array),
4141
// 4-byte integer types
42-
DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)),
43-
DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)),
42+
DataType::Int32 => Ok(Arc::new(PrimitiveFilter::<Int32Type>::try_new(&in_array)?)),
43+
DataType::UInt32 => Ok(Arc::new(PrimitiveFilter::<UInt32Type>::try_new(&in_array)?)),
4444
// 8-byte integer types
45-
DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)),
46-
DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)),
45+
DataType::Int64 => Ok(Arc::new(PrimitiveFilter::<Int64Type>::try_new(&in_array)?)),
46+
DataType::UInt64 => Ok(Arc::new(PrimitiveFilter::<UInt64Type>::try_new(&in_array)?)),
4747
// Float primitive types (use ordered wrappers for Hash/Eq)
4848
DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)),
4949
DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)),
@@ -104,139 +104,6 @@ impl From<f64> for OrderedFloat64 {
104104
}
105105
}
106106

107-
// Macro to generate specialized StaticFilter implementations for primitive types
108-
macro_rules! primitive_static_filter {
109-
($Name:ident, $ArrowType:ty) => {
110-
struct $Name {
111-
null_count: usize,
112-
values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>,
113-
}
114-
115-
impl $Name {
116-
fn try_new(in_array: &ArrayRef) -> Result<Self> {
117-
let in_array = in_array
118-
.as_primitive_opt::<$ArrowType>()
119-
.ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
120-
121-
let mut values = HashSet::with_capacity(in_array.len());
122-
let null_count = in_array.null_count();
123-
124-
for v in in_array.iter().flatten() {
125-
values.insert(v);
126-
}
127-
128-
Ok(Self { null_count, values })
129-
}
130-
}
131-
132-
impl StaticFilter for $Name {
133-
fn null_count(&self) -> usize {
134-
self.null_count
135-
}
136-
137-
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
138-
// Handle dictionary arrays by recursing on the values
139-
downcast_dictionary_array! {
140-
v => {
141-
let values_contains = self.contains(v.values().as_ref(), negated)?;
142-
let result = take(&values_contains, v.keys(), None)?;
143-
return Ok(downcast_array(result.as_ref()))
144-
}
145-
_ => {}
146-
}
147-
148-
let v = v
149-
.as_primitive_opt::<$ArrowType>()
150-
.ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
151-
152-
let haystack_has_nulls = self.null_count > 0;
153-
154-
let needle_values = v.values();
155-
let needle_nulls = v.nulls();
156-
let needle_has_nulls = v.null_count() > 0;
157-
158-
// Truth table for `value [NOT] IN (set)` with SQL three-valued logic:
159-
// ("-" means the value doesn't affect the result)
160-
//
161-
// | needle_null | haystack_null | negated | in set? | result |
162-
// |-------------|---------------|---------|---------|--------|
163-
// | true | - | false | - | null |
164-
// | true | - | true | - | null |
165-
// | false | true | false | yes | true |
166-
// | false | true | false | no | null |
167-
// | false | true | true | yes | false |
168-
// | false | true | true | no | null |
169-
// | false | false | false | yes | true |
170-
// | false | false | false | no | false |
171-
// | false | false | true | yes | false |
172-
// | false | false | true | no | true |
173-
174-
// Compute the "contains" result using collect_bool (fast batched approach)
175-
// This ignores nulls - we handle them separately
176-
let contains_buffer = if negated {
177-
BooleanBuffer::collect_bool(needle_values.len(), |i| {
178-
!self.values.contains(&needle_values[i])
179-
})
180-
} else {
181-
BooleanBuffer::collect_bool(needle_values.len(), |i| {
182-
self.values.contains(&needle_values[i])
183-
})
184-
};
185-
186-
// Compute the null mask
187-
// Output is null when:
188-
// 1. needle value is null, OR
189-
// 2. needle value is not in set AND haystack has nulls
190-
let result_nulls = match (needle_has_nulls, haystack_has_nulls) {
191-
(false, false) => {
192-
// No nulls anywhere
193-
None
194-
}
195-
(true, false) => {
196-
// Only needle has nulls - just use needle's null mask
197-
needle_nulls.cloned()
198-
}
199-
(false, true) => {
200-
// Only haystack has nulls - result is null when value not in set
201-
// Valid (not null) when original "in set" is true
202-
// For NOT IN: contains_buffer = !original, so validity = !contains_buffer
203-
let validity = if negated {
204-
!&contains_buffer
205-
} else {
206-
contains_buffer.clone()
207-
};
208-
Some(NullBuffer::new(validity))
209-
}
210-
(true, true) => {
211-
// Both have nulls - combine needle nulls with haystack-induced nulls
212-
let needle_validity = needle_nulls.map(|n| n.inner().clone())
213-
.unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len()));
214-
215-
// Valid when original "in set" is true (see above)
216-
let haystack_validity = if negated {
217-
!&contains_buffer
218-
} else {
219-
contains_buffer.clone()
220-
};
221-
222-
// Combined validity: valid only where both are valid
223-
let combined_validity = &needle_validity & &haystack_validity;
224-
Some(NullBuffer::new(combined_validity))
225-
}
226-
};
227-
228-
Ok(BooleanArray::new(contains_buffer, result_nulls))
229-
}
230-
}
231-
};
232-
}
233-
234-
// Generate specialized filters for 4-byte and 8-byte integer primitive types
235-
// (1-byte and 2-byte types use BitmapFilter instead)
236-
primitive_static_filter!(Int32StaticFilter, Int32Type);
237-
primitive_static_filter!(Int64StaticFilter, Int64Type);
238-
primitive_static_filter!(UInt32StaticFilter, UInt32Type);
239-
primitive_static_filter!(UInt64StaticFilter, UInt64Type);
240107

241108
// Macro to generate specialized StaticFilter implementations for float types
242109
// Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics

datafusion/physical-expr/src/expressions/in_list/transform.rs

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@
2020
//! This module provides type reinterpretation for optimizing filter dispatch.
2121
//! For equality comparison, only the bit pattern matters, so we can:
2222
//! - Reinterpret signed integers as unsigned (Int8 → UInt8)
23+
//! - Reinterpret floats as unsigned integers (Float64 → UInt64)
2324
//!
24-
//! This allows using a single filter implementation (e.g., BitmapFilter for UInt8)
25-
//! to handle multiple types (Int8, UInt8) that share the same byte width.
25+
//! This allows using a single filter implementation (e.g., for UInt64) to handle
26+
//! multiple types (Int64, Float64, Timestamp, Duration) that share the same
27+
//! byte width, reducing code duplication.
2628
29+
use std::hash::Hash;
2730
use std::sync::Arc;
2831

2932
use arrow::array::{Array, ArrayRef, BooleanArray, PrimitiveArray};
@@ -32,7 +35,7 @@ use arrow::datatypes::ArrowPrimitiveType;
3235
use datafusion_common::Result;
3336

3437
use super::array_filter::StaticFilter;
35-
use super::primitive::{BitmapFilter, BitmapFilterConfig};
38+
use super::primitive::{BitmapFilter, BitmapFilterConfig, PrimitiveFilter};
3639
use super::result::handle_dictionary;
3740

3841
// =============================================================================
@@ -59,6 +62,32 @@ impl<C: BitmapFilterConfig> StaticFilter for ReinterpretedBitmap<C> {
5962
}
6063
}
6164

65+
/// Reinterpreting filter for hash-based lookups.
66+
///
67+
/// Zero-copy: reinterprets input buffer directly as target type slice.
68+
struct ReinterpretedPrimitive<D: ArrowPrimitiveType> {
69+
inner: PrimitiveFilter<D>,
70+
}
71+
72+
impl<D> StaticFilter for ReinterpretedPrimitive<D>
73+
where
74+
D: ArrowPrimitiveType + 'static,
75+
D::Native: Hash + Eq + Send + Sync + 'static,
76+
{
77+
fn null_count(&self) -> usize {
78+
self.inner.null_count()
79+
}
80+
81+
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
82+
handle_dictionary!(self, v, negated);
83+
84+
let data = v.to_data();
85+
let values: &[D::Native] = data.buffers()[0].typed_data();
86+
87+
Ok(self.inner.contains_slice(values, data.nulls(), negated))
88+
}
89+
}
90+
6291
/// Reinterprets any primitive-like array as the target primitive type T by extracting
6392
/// the underlying buffer.
6493
///
@@ -87,3 +116,20 @@ where
87116
let inner = BitmapFilter::<C>::try_new(&reinterpreted)?;
88117
Ok(Arc::new(ReinterpretedBitmap { inner }))
89118
}
119+
120+
/// Creates a hash-based filter, reinterpreting types if needed.
121+
pub(crate) fn make_primitive_filter<D>(
122+
in_array: &ArrayRef,
123+
) -> Result<Arc<dyn StaticFilter + Send + Sync>>
124+
where
125+
D: ArrowPrimitiveType + 'static,
126+
D::Native: Hash + Eq + Send + Sync + 'static,
127+
{
128+
if in_array.data_type() == &D::DATA_TYPE {
129+
return Ok(Arc::new(PrimitiveFilter::<D>::try_new(in_array)?));
130+
}
131+
132+
let reinterpreted = reinterpret_any_primitive_to::<D>(in_array.as_ref());
133+
let inner = PrimitiveFilter::<D>::try_new(&reinterpreted)?;
134+
Ok(Arc::new(ReinterpretedPrimitive { inner }))
135+
}

0 commit comments

Comments
 (0)