Skip to content

Commit 31545cd

Browse files
Implement Branchless Filter for small primitive lists
Adds a const-generic unrolled comparison chain that avoids CPU branching. Outperforms hash lookups for very small lists. Triggers for primitives when list size <= 32 (4-byte), 16 (8-byte), or 4 (16-byte).
1 parent 18f47cf commit 31545cd

3 files changed

Lines changed: 344 additions & 19 deletions

File tree

datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,91 @@ impl<C: BitmapFilterConfig> StaticFilter for BitmapFilter<C> {
173173
}
174174
}
175175

176+
// =============================================================================
177+
// BRANCHLESS FILTER (Const Generic for Small Lists)
178+
// =============================================================================
179+
180+
/// A branchless filter for very small IN lists (0-16 elements).
181+
///
182+
/// Uses const generics to unroll the membership check into a fixed-size
183+
/// comparison chain, outperforming hash lookups for small lists due to:
184+
/// - No branching (uses bitwise OR to combine comparisons)
185+
/// - Better CPU pipelining
186+
/// - No hash computation overhead
187+
pub(crate) struct BranchlessFilter<T: ArrowPrimitiveType, const N: usize> {
188+
null_count: usize,
189+
values: [T::Native; N],
190+
}
191+
192+
impl<T: ArrowPrimitiveType, const N: usize> BranchlessFilter<T, N>
193+
where
194+
T::Native: Copy + PartialEq,
195+
{
196+
/// Try to create a branchless filter if the array has exactly N non-null values.
197+
pub(crate) fn try_new(in_array: &ArrayRef) -> Option<Result<Self>> {
198+
let in_array = in_array.as_primitive_opt::<T>()?;
199+
let non_null_count = in_array.len() - in_array.null_count();
200+
if non_null_count != N {
201+
return None;
202+
}
203+
let values: Vec<_> = in_array.iter().flatten().collect();
204+
// Use default_value() from ArrowPrimitiveType trait instead of Default::default()
205+
let mut arr = [T::default_value(); N];
206+
arr.copy_from_slice(&values);
207+
Some(Ok(Self {
208+
null_count: in_array.null_count(),
209+
values: arr,
210+
}))
211+
}
212+
213+
/// Branchless membership check using OR-chain.
214+
#[inline(always)]
215+
fn check(&self, needle: T::Native) -> bool {
216+
self.values
217+
.iter()
218+
.fold(false, |acc, &v| acc | (v == needle))
219+
}
220+
221+
/// Check membership using a raw values slice (zero-copy path for type reinterpretation).
222+
#[inline]
223+
pub(crate) fn contains_slice(
224+
&self,
225+
values: &[T::Native],
226+
nulls: Option<&NullBuffer>,
227+
negated: bool,
228+
) -> BooleanArray {
229+
build_in_list_result(values.len(), nulls, self.null_count > 0, negated, |i| {
230+
self.check(unsafe { *values.get_unchecked(i) })
231+
})
232+
}
233+
}
234+
235+
impl<T: ArrowPrimitiveType, const N: usize> StaticFilter for BranchlessFilter<T, N>
236+
where
237+
T::Native: Copy + PartialEq + Send + Sync,
238+
{
239+
fn null_count(&self) -> usize {
240+
self.null_count
241+
}
242+
243+
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
244+
handle_dictionary!(self, v, negated);
245+
let v = v.as_primitive_opt::<T>().ok_or_else(|| {
246+
exec_datafusion_err!("Failed to downcast array to primitive type")
247+
})?;
248+
let input_values = v.values();
249+
Ok(build_in_list_result(
250+
v.len(),
251+
v.nulls(),
252+
self.null_count > 0,
253+
negated,
254+
// SAFETY: i is in bounds since we iterate 0..v.len()
255+
#[inline(always)]
256+
|i| self.check(unsafe { *input_values.get_unchecked(i) }),
257+
))
258+
}
259+
}
260+
176261
// =============================================================================
177262
// LEGACY FILTERS (to be replaced by optimized ones in subsequent commits)
178263
// =============================================================================

datafusion/physical-expr/src/expressions/in_list/strategy.rs

Lines changed: 112 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,125 @@
1818
use std::sync::Arc;
1919

2020
use arrow::array::ArrayRef;
21-
use arrow::datatypes::DataType;
22-
use datafusion_common::Result;
21+
use arrow::datatypes::*;
22+
use datafusion_common::{Result, exec_datafusion_err};
2323

2424
use super::array_static_filter::ArrayStaticFilter;
2525
use super::primitive_filter::*;
2626
use super::static_filter::StaticFilter;
27-
use super::transform::make_bitmap_filter;
27+
use super::transform::{make_bitmap_filter, make_branchless_filter};
2828

29+
// =============================================================================
30+
// LOOKUP STRATEGY THRESHOLDS (tuned via microbenchmarks)
31+
// =============================================================================
32+
33+
/// Maximum list size for branchless lookup on 4-byte primitives (Int32, UInt32, Float32).
34+
const BRANCHLESS_MAX_4B: usize = 32;
35+
36+
/// Maximum list size for branchless lookup on 8-byte primitives (Int64, UInt64, Float64).
37+
const BRANCHLESS_MAX_8B: usize = 16;
38+
39+
/// Maximum list size for branchless lookup on 16-byte types (Decimal128).
40+
const BRANCHLESS_MAX_16B: usize = 4;
41+
42+
// =============================================================================
43+
// FILTER STRATEGY SELECTION
44+
// =============================================================================
45+
46+
/// The lookup strategy to use for a given data type and list size.
47+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
48+
enum FilterStrategy {
49+
/// Bitmap filter for u8/u16 - O(1) bit test, always fastest for these types.
50+
Bitmap1B,
51+
Bitmap2B,
52+
/// Branchless OR-chain for small lists.
53+
Branchless,
54+
/// Generic ArrayStaticFilter fallback.
55+
Generic,
56+
}
57+
58+
/// Determines the optimal lookup strategy based on data type and list size.
59+
fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy {
60+
match dt.primitive_width() {
61+
Some(1) => FilterStrategy::Bitmap1B,
62+
Some(2) => FilterStrategy::Bitmap2B,
63+
Some(4) => {
64+
if len <= BRANCHLESS_MAX_4B {
65+
FilterStrategy::Branchless
66+
} else {
67+
FilterStrategy::Generic
68+
}
69+
}
70+
Some(8) => {
71+
if len <= BRANCHLESS_MAX_8B {
72+
FilterStrategy::Branchless
73+
} else {
74+
FilterStrategy::Generic
75+
}
76+
}
77+
Some(16) => {
78+
if len <= BRANCHLESS_MAX_16B {
79+
FilterStrategy::Branchless
80+
} else {
81+
FilterStrategy::Generic
82+
}
83+
}
84+
_ => FilterStrategy::Generic,
85+
}
86+
}
87+
88+
// =============================================================================
89+
// FILTER INSTANTIATION
90+
// =============================================================================
91+
92+
/// Creates the optimal static filter for the given array.
2993
pub(super) fn instantiate_static_filter(
3094
in_array: ArrayRef,
3195
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
32-
match in_array.data_type() {
33-
DataType::Int8 | DataType::UInt8 => make_bitmap_filter::<U8Config>(&in_array),
34-
DataType::Int16 | DataType::UInt16 => make_bitmap_filter::<U16Config>(&in_array),
35-
DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)),
36-
DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)),
37-
DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)),
38-
DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)),
39-
// Float primitive types (use ordered wrappers for Hash/Eq)
40-
DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)),
41-
DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)),
42-
_ => {
43-
/* fall through to generic implementation for unsupported types (Struct, etc.) */
44-
Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?))
45-
}
96+
use FilterStrategy::*;
97+
98+
let len = in_array.len();
99+
let dt = in_array.data_type();
100+
let strategy = select_strategy(dt, len);
101+
102+
match (dt, strategy) {
103+
// Bitmap filters for 1-byte and 2-byte types
104+
(_, Bitmap1B) => make_bitmap_filter::<U8Config>(&in_array),
105+
(_, Bitmap2B) => make_bitmap_filter::<U16Config>(&in_array),
106+
107+
// Branchless filters for small lists of primitives
108+
(_, Branchless) => dispatch_branchless(&in_array).ok_or_else(|| {
109+
exec_datafusion_err!(
110+
"Branchless strategy selected but no filter for {:?}",
111+
dt
112+
)
113+
})?,
114+
115+
// Fallback for larger primitive lists or complex types.
116+
(_, Generic) => match dt {
117+
DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)),
118+
DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)),
119+
DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)),
120+
DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)),
121+
DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)),
122+
DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)),
123+
_ => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)),
124+
},
125+
}
126+
}
127+
128+
// =============================================================================
129+
// TYPE DISPATCH
130+
// =============================================================================
131+
132+
fn dispatch_branchless(
133+
arr: &ArrayRef,
134+
) -> Option<Result<Arc<dyn StaticFilter + Send + Sync>>> {
135+
// Dispatch to width-specific branchless filter.
136+
match arr.data_type().primitive_width() {
137+
Some(4) => Some(make_branchless_filter::<UInt32Type>(arr, 4)),
138+
Some(8) => Some(make_branchless_filter::<UInt64Type>(arr, 8)),
139+
Some(16) => Some(make_branchless_filter::<Decimal128Type>(arr, 16)),
140+
_ => None,
46141
}
47142
}

0 commit comments

Comments
 (0)