Skip to content

Commit d2777b1

Browse files
Implement Branchless Filter for small primitive lists
Adds a const-generic unrolled comparison chain that avoids CPU branching. Outperforms hash lookups for very small lists. Triggers for primitives when list size <= 32 (4-byte), 16 (8-byte), or 4 (16-byte).
1 parent 48afee1 commit d2777b1

3 files changed

Lines changed: 343 additions & 18 deletions

File tree

datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,91 @@ impl<C: BitmapFilterConfig> StaticFilter for BitmapFilter<C> {
173173
}
174174
}
175175

176+
// =============================================================================
177+
// BRANCHLESS FILTER (Const Generic for Small Lists)
178+
// =============================================================================
179+
180+
/// A branchless filter for very small IN lists (0-16 elements).
181+
///
182+
/// Uses const generics to unroll the membership check into a fixed-size
183+
/// comparison chain, outperforming hash lookups for small lists due to:
184+
/// - No branching (uses bitwise OR to combine comparisons)
185+
/// - Better CPU pipelining
186+
/// - No hash computation overhead
187+
pub(crate) struct BranchlessFilter<T: ArrowPrimitiveType, const N: usize> {
188+
null_count: usize,
189+
values: [T::Native; N],
190+
}
191+
192+
impl<T: ArrowPrimitiveType, const N: usize> BranchlessFilter<T, N>
193+
where
194+
T::Native: Copy + PartialEq,
195+
{
196+
/// Try to create a branchless filter if the array has exactly N non-null values.
197+
pub(crate) fn try_new(in_array: &ArrayRef) -> Option<Result<Self>> {
198+
let in_array = in_array.as_primitive_opt::<T>()?;
199+
let non_null_count = in_array.len() - in_array.null_count();
200+
if non_null_count != N {
201+
return None;
202+
}
203+
let values: Vec<_> = in_array.iter().flatten().collect();
204+
// Use default_value() from ArrowPrimitiveType trait instead of Default::default()
205+
let mut arr = [T::default_value(); N];
206+
arr.copy_from_slice(&values);
207+
Some(Ok(Self {
208+
null_count: in_array.null_count(),
209+
values: arr,
210+
}))
211+
}
212+
213+
/// Branchless membership check using OR-chain.
214+
#[inline(always)]
215+
fn check(&self, needle: T::Native) -> bool {
216+
self.values
217+
.iter()
218+
.fold(false, |acc, &v| acc | (v == needle))
219+
}
220+
221+
/// Check membership using a raw values slice (zero-copy path for type reinterpretation).
222+
#[inline]
223+
pub(crate) fn contains_slice(
224+
&self,
225+
values: &[T::Native],
226+
nulls: Option<&NullBuffer>,
227+
negated: bool,
228+
) -> BooleanArray {
229+
build_in_list_result(values.len(), nulls, self.null_count > 0, negated, |i| {
230+
self.check(unsafe { *values.get_unchecked(i) })
231+
})
232+
}
233+
}
234+
235+
impl<T: ArrowPrimitiveType, const N: usize> StaticFilter for BranchlessFilter<T, N>
236+
where
237+
T::Native: Copy + PartialEq + Send + Sync,
238+
{
239+
fn null_count(&self) -> usize {
240+
self.null_count
241+
}
242+
243+
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
244+
handle_dictionary!(self, v, negated);
245+
let v = v.as_primitive_opt::<T>().ok_or_else(|| {
246+
exec_datafusion_err!("Failed to downcast array to primitive type")
247+
})?;
248+
let input_values = v.values();
249+
Ok(build_in_list_result(
250+
v.len(),
251+
v.nulls(),
252+
self.null_count > 0,
253+
negated,
254+
// SAFETY: i is in bounds since we iterate 0..v.len()
255+
#[inline(always)]
256+
|i| self.check(unsafe { *input_values.get_unchecked(i) }),
257+
))
258+
}
259+
}
260+
176261
// =============================================================================
177262
// LEGACY FILTERS (to be replaced by optimized ones in subsequent commits)
178263
// =============================================================================

datafusion/physical-expr/src/expressions/in_list/strategy.rs

Lines changed: 111 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,30 +20,125 @@
2020
use std::sync::Arc;
2121

2222
use arrow::array::ArrayRef;
23-
use arrow::datatypes::DataType;
24-
use datafusion_common::Result;
23+
use arrow::datatypes::*;
24+
use datafusion_common::{Result, exec_datafusion_err};
2525

2626
use super::nested_filter::NestedTypeFilter;
2727
use super::primitive_filter::*;
2828
use super::static_filter::StaticFilter;
29-
use super::transform::make_bitmap_filter;
29+
use super::transform::{make_bitmap_filter, make_branchless_filter};
30+
31+
// =============================================================================
32+
// LOOKUP STRATEGY THRESHOLDS (tuned via microbenchmarks)
33+
// =============================================================================
34+
35+
/// Maximum list size for branchless lookup on any type whose primitive width is 4 bytes (e.g. Int32, UInt32, Float32).
36+
const BRANCHLESS_MAX_4B: usize = 32;
37+
38+
/// Maximum list size for branchless lookup on any type whose primitive width is 8 bytes (e.g. Int64, UInt64, Float64).
39+
const BRANCHLESS_MAX_8B: usize = 16;
40+
41+
/// Maximum list size for branchless lookup on any type whose primitive width is 16 bytes (e.g. Decimal128).
42+
const BRANCHLESS_MAX_16B: usize = 4;
43+
44+
// =============================================================================
45+
// FILTER STRATEGY SELECTION
46+
// =============================================================================
47+
48+
/// The lookup strategy to use for a given data type and list size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FilterStrategy {
    /// Bitmap filter for types with a primitive width of 1 byte.
    /// Chosen regardless of list size.
    Bitmap1B,
    /// Bitmap filter for types with a primitive width of 2 bytes.
    /// Chosen regardless of list size.
    Bitmap2B,
    /// Branchless OR-chain for small lists of 4-, 8- or 16-byte primitives.
    Branchless,
    /// Fallback: the typed legacy filters or the nested-type filter.
    Generic,
}
59+
60+
/// Determines the optimal lookup strategy based on data type and list size.
61+
fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy {
62+
match dt.primitive_width() {
63+
Some(1) => FilterStrategy::Bitmap1B,
64+
Some(2) => FilterStrategy::Bitmap2B,
65+
Some(4) => {
66+
if len <= BRANCHLESS_MAX_4B {
67+
FilterStrategy::Branchless
68+
} else {
69+
FilterStrategy::Generic
70+
}
71+
}
72+
Some(8) => {
73+
if len <= BRANCHLESS_MAX_8B {
74+
FilterStrategy::Branchless
75+
} else {
76+
FilterStrategy::Generic
77+
}
78+
}
79+
Some(16) => {
80+
if len <= BRANCHLESS_MAX_16B {
81+
FilterStrategy::Branchless
82+
} else {
83+
FilterStrategy::Generic
84+
}
85+
}
86+
_ => FilterStrategy::Generic,
87+
}
88+
}
89+
90+
// =============================================================================
91+
// FILTER INSTANTIATION
92+
// =============================================================================
3093

3194
/// Creates the optimal static filter for the given array.
32-
///
33-
/// This is the main entry point for filter creation. It analyzes the array's
34-
/// data type and size to select the best lookup strategy.
3595
pub(crate) fn instantiate_static_filter(
3696
in_array: ArrayRef,
3797
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
38-
match in_array.data_type() {
39-
DataType::Int8 | DataType::UInt8 => make_bitmap_filter::<U8Config>(&in_array),
40-
DataType::Int16 | DataType::UInt16 => make_bitmap_filter::<U16Config>(&in_array),
41-
DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)),
42-
DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)),
43-
DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)),
44-
DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)),
45-
DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)),
46-
DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)),
47-
_ => Ok(Arc::new(NestedTypeFilter::try_new(in_array)?)),
98+
use FilterStrategy::*;
99+
100+
let len = in_array.len();
101+
let dt = in_array.data_type();
102+
let strategy = select_strategy(dt, len);
103+
104+
match (dt, strategy) {
105+
// Bitmap filters for 1-byte and 2-byte types
106+
(_, Bitmap1B) => make_bitmap_filter::<U8Config>(&in_array),
107+
(_, Bitmap2B) => make_bitmap_filter::<U16Config>(&in_array),
108+
109+
// Branchless filters for small lists of primitives
110+
(_, Branchless) => dispatch_branchless(&in_array).ok_or_else(|| {
111+
exec_datafusion_err!(
112+
"Branchless strategy selected but no filter for {:?}",
113+
dt
114+
)
115+
})?,
116+
117+
// Fallback for larger primitive lists (Legacy HashSet) or complex types (NestedTypeFilter)
118+
(_, Generic) => match dt {
119+
DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)),
120+
DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)),
121+
DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)),
122+
DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)),
123+
DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)),
124+
DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)),
125+
_ => Ok(Arc::new(NestedTypeFilter::try_new(in_array)?)),
126+
},
127+
}
128+
}
129+
130+
// =============================================================================
131+
// TYPE DISPATCH
132+
// =============================================================================
133+
134+
fn dispatch_branchless(
135+
arr: &ArrayRef,
136+
) -> Option<Result<Arc<dyn StaticFilter + Send + Sync>>> {
137+
// Dispatch to width-specific branchless filter.
138+
match arr.data_type().primitive_width() {
139+
Some(4) => Some(make_branchless_filter::<UInt32Type>(arr, 4)),
140+
Some(8) => Some(make_branchless_filter::<UInt64Type>(arr, 8)),
141+
Some(16) => Some(make_branchless_filter::<Decimal128Type>(arr, 16)),
142+
_ => None,
48143
}
49144
}

0 commit comments

Comments
 (0)