Skip to content

Commit 2b20c45

Browse files
perf(in_list): add BranchlessFilter for small IN lists
Adds BranchlessFilter<T, N> - a const-generic filter that unrolls membership checks into a fixed-size OR-chain comparison. For small lists (≤16 elements), this outperforms hash lookups due to: - No branching (uses bitwise OR to combine comparisons) - Better CPU pipelining - No hash computation overhead Strategy selection thresholds (tuned via benchmarks): - 4-byte types (Int32, Float32): branchless up to 16 elements - 8-byte types (Int64, Float64): branchless up to 16 elements - 16-byte types (Decimal128): branchless up to 4 elements
1 parent e6a8d57 commit 2b20c45

3 files changed

Lines changed: 345 additions & 21 deletions

File tree

datafusion/physical-expr/src/expressions/in_list/primitive.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,92 @@ where
248248
))
249249
}
250250
}
251+
252+
// =============================================================================
253+
// BRANCHLESS FILTER (Const Generic for Small Lists)
254+
// =============================================================================
255+
256+
/// A branchless filter for very small IN lists (0-16 elements).
257+
///
258+
/// Uses const generics to unroll the membership check into a fixed-size
259+
/// comparison chain, outperforming hash lookups for small lists due to:
260+
/// - No branching (uses bitwise OR to combine comparisons)
261+
/// - Better CPU pipelining
262+
/// - No hash computation overhead
263+
pub(crate) struct BranchlessFilter<T: ArrowPrimitiveType, const N: usize> {
264+
null_count: usize,
265+
values: [T::Native; N],
266+
}
267+
268+
impl<T: ArrowPrimitiveType, const N: usize> BranchlessFilter<T, N>
269+
where
270+
T::Native: Copy + PartialEq,
271+
{
272+
/// Try to create a branchless filter if the array has exactly N non-null values.
273+
pub(crate) fn try_new(in_array: &ArrayRef) -> Option<Result<Self>> {
274+
let in_array = in_array.as_primitive_opt::<T>()?;
275+
let non_null_count = in_array.len() - in_array.null_count();
276+
if non_null_count != N {
277+
return None;
278+
}
279+
let values: Vec<_> = in_array.iter().flatten().collect();
280+
// Use default_value() from ArrowPrimitiveType trait instead of Default::default()
281+
let mut arr = [T::default_value(); N];
282+
arr.copy_from_slice(&values);
283+
Some(Ok(Self {
284+
null_count: in_array.null_count(),
285+
values: arr,
286+
}))
287+
}
288+
289+
/// Branchless membership check using OR-chain.
290+
#[inline(always)]
291+
fn check(&self, needle: T::Native) -> bool {
292+
self.values
293+
.iter()
294+
.fold(false, |acc, &v| acc | (v == needle))
295+
}
296+
297+
/// Check membership using a raw values slice (zero-copy path for type reinterpretation).
298+
#[inline]
299+
pub(crate) fn contains_slice(
300+
&self,
301+
values: &[T::Native],
302+
nulls: Option<&arrow::buffer::NullBuffer>,
303+
negated: bool,
304+
) -> BooleanArray {
305+
build_in_list_result(
306+
values.len(),
307+
nulls,
308+
self.null_count > 0,
309+
negated,
310+
|i| self.check(unsafe { *values.get_unchecked(i) }),
311+
)
312+
}
313+
}
314+
315+
impl<T: ArrowPrimitiveType, const N: usize> StaticFilter for BranchlessFilter<T, N>
316+
where
317+
T::Native: Copy + PartialEq + Send + Sync,
318+
{
319+
fn null_count(&self) -> usize {
320+
self.null_count
321+
}
322+
323+
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
324+
handle_dictionary!(self, v, negated);
325+
let v = v.as_primitive_opt::<T>().ok_or_else(|| {
326+
exec_datafusion_err!("Failed to downcast array to primitive type")
327+
})?;
328+
let input_values = v.values();
329+
Ok(build_in_list_result(
330+
v.len(),
331+
v.nulls(),
332+
self.null_count > 0,
333+
negated,
334+
// SAFETY: i is in bounds since we iterate 0..v.len()
335+
#[inline(always)]
336+
|i| self.check(unsafe { *input_values.get_unchecked(i) }),
337+
))
338+
}
339+
}

datafusion/physical-expr/src/expressions/in_list/strategy.rs

Lines changed: 198 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,39 +15,217 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
//! Filter strategy selection for InList expressions
18+
//! Filter selection strategy for InList expressions
19+
//!
20+
//! Selects the optimal lookup strategy based on data type and list size:
21+
//!
22+
//! - 1-byte types (Int8/UInt8): bitmap (32 bytes, O(1) bit test)
23+
//! - 2-byte types (Int16/UInt16): bitmap (8 KB, O(1) bit test)
24+
//! - 4-byte types (Int32/Float32): branchless (≤16) or hash (>16)
25+
//! - 8-byte types (Int64/Float64): branchless (≤16) or hash (>16)
26+
//! - 16-byte types (Decimal128): branchless (≤4) or hash (>4)
27+
//! - Other types: generic ArrayStaticFilter
1928
29+
use std::hash::Hash;
2030
use std::sync::Arc;
2131

2232
use arrow::array::ArrayRef;
2333
use arrow::datatypes::*;
2434
use datafusion_common::Result;
2535

26-
use super::array_filter::{ArrayStaticFilter, StaticFilter};
36+
use super::array_filter::ArrayStaticFilter;
37+
use super::array_filter::StaticFilter;
2738
use super::primitive::{PrimitiveFilter, U8Config, U16Config};
28-
use super::transform::{make_bitmap_filter, make_primitive_filter};
39+
use super::transform::{make_bitmap_filter, make_branchless_filter, make_primitive_filter};
2940

41+
// =============================================================================
42+
// LOOKUP STRATEGY THRESHOLDS (tuned via microbenchmarks)
43+
// =============================================================================
44+
//
45+
// Based on minimum batch time (8192 lookups per batch):
46+
// - Int8 (1 byte): BITMAP (32 bytes, always fastest)
47+
// - Int16 (2 bytes): BITMAP (8 KB, always fastest)
48+
// - Int32 (4 bytes): branchless up to 16, then hashset
49+
// - Int64 (8 bytes): branchless up to 16, then hashset
50+
// - Int128 (16 bytes): branchless up to 4, then hashset
51+
// - Other types: hashset (via ArrayStaticFilter)
52+
//
53+
// NOTE: Binary search and linear scan were benchmarked but consistently
54+
// lost to the strategies above at all tested list sizes.
55+
56+
/// Maximum list size for branchless lookup on 4-byte primitives (Int32, UInt32, Float32).
57+
const BRANCHLESS_MAX_4B: usize = 16;
58+
59+
/// Maximum list size for branchless lookup on 8-byte primitives (Int64, UInt64, Float64).
60+
const BRANCHLESS_MAX_8B: usize = 16;
61+
62+
/// Maximum list size for branchless lookup on 16-byte types (Decimal128).
63+
const BRANCHLESS_MAX_16B: usize = 4;
64+
65+
// =============================================================================
66+
// FILTER STRATEGY SELECTION
67+
// =============================================================================
68+
69+
/// The lookup strategy to use for a given data type and list size.
70+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
71+
enum FilterStrategy {
72+
/// Bitmap filter for u8/u16 - O(1) bit test, always fastest for these types.
73+
Bitmap1B,
74+
Bitmap2B,
75+
/// Branchless OR-chain for small lists.
76+
Branchless,
77+
/// HashSet for larger lists.
78+
Hashed,
79+
/// Generic ArrayStaticFilter fallback.
80+
Generic,
81+
}
82+
83+
/// Determines the optimal lookup strategy based on data type and list size.
84+
///
85+
/// For 1-byte and 2-byte types, bitmap is always used (benchmarks show it's
86+
/// faster than both branchless and hashed at all list sizes).
87+
/// For larger types, cutoffs are tuned per byte-width.
88+
fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy {
89+
match dt.primitive_width() {
90+
Some(1) => FilterStrategy::Bitmap1B,
91+
Some(2) => FilterStrategy::Bitmap2B,
92+
Some(4) => {
93+
if len <= BRANCHLESS_MAX_4B {
94+
FilterStrategy::Branchless
95+
} else {
96+
FilterStrategy::Hashed
97+
}
98+
}
99+
Some(8) => {
100+
if len <= BRANCHLESS_MAX_8B {
101+
FilterStrategy::Branchless
102+
} else {
103+
FilterStrategy::Hashed
104+
}
105+
}
106+
Some(16) => {
107+
if len <= BRANCHLESS_MAX_16B {
108+
FilterStrategy::Branchless
109+
} else {
110+
FilterStrategy::Hashed
111+
}
112+
}
113+
_ => FilterStrategy::Generic,
114+
}
115+
}
116+
117+
// =============================================================================
118+
// FILTER INSTANTIATION
119+
// =============================================================================
120+
121+
/// Creates the optimal static filter for the given array.
122+
///
123+
/// This is the main entry point for filter creation. It analyzes the array's
124+
/// data type and size to select the best lookup strategy.
30125
pub(crate) fn instantiate_static_filter(
31126
in_array: ArrayRef,
32127
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
33-
match in_array.data_type() {
34-
// 1-byte types: use bitmap (256 bits = 32 bytes)
35-
DataType::Int8 | DataType::UInt8 => make_bitmap_filter::<U8Config>(&in_array),
36-
// 2-byte types: use bitmap (65536 bits = 8 KB)
37-
DataType::Int16 | DataType::UInt16 => make_bitmap_filter::<U16Config>(&in_array),
38-
// 4-byte integer types
39-
DataType::Int32 => Ok(Arc::new(PrimitiveFilter::<Int32Type>::try_new(&in_array)?)),
40-
DataType::UInt32 => Ok(Arc::new(PrimitiveFilter::<UInt32Type>::try_new(&in_array)?)),
41-
// 8-byte integer types
42-
DataType::Int64 => Ok(Arc::new(PrimitiveFilter::<Int64Type>::try_new(&in_array)?)),
43-
DataType::UInt64 => Ok(Arc::new(PrimitiveFilter::<UInt64Type>::try_new(&in_array)?)),
44-
// Float types: reinterpret as unsigned integers (same bit pattern = equal)
45-
DataType::Float32 => make_primitive_filter::<UInt32Type>(&in_array),
46-
DataType::Float64 => make_primitive_filter::<UInt64Type>(&in_array),
47-
_ => {
48-
/* fall through to generic implementation for unsupported types (Struct, etc.) */
49-
Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?))
128+
let len = in_array.len();
129+
let dt = in_array.data_type();
130+
131+
match select_strategy(dt, len) {
132+
FilterStrategy::Bitmap1B => dispatch_bitmap_u8(&in_array),
133+
FilterStrategy::Bitmap2B => dispatch_bitmap_u16(&in_array),
134+
FilterStrategy::Branchless => dispatch_filter(&in_array, dispatch_branchless),
135+
FilterStrategy::Hashed => dispatch_filter(&in_array, dispatch_hashed),
136+
FilterStrategy::Generic => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)),
137+
}
138+
}
139+
140+
/// Generic filter dispatcher with fallback to ArrayStaticFilter.
141+
fn dispatch_filter<F>(
142+
in_array: &ArrayRef,
143+
dispatch: F,
144+
) -> Result<Arc<dyn StaticFilter + Send + Sync>>
145+
where
146+
F: Fn(&ArrayRef) -> Option<Result<Arc<dyn StaticFilter + Send + Sync>>>,
147+
{
148+
dispatch(in_array).unwrap_or_else(|| {
149+
Ok(Arc::new(ArrayStaticFilter::try_new(Arc::clone(in_array))?))
150+
})
151+
}
152+
153+
fn dispatch_bitmap_u8(
154+
in_array: &ArrayRef,
155+
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
156+
make_bitmap_filter::<U8Config>(in_array)
157+
}
158+
159+
fn dispatch_bitmap_u16(
160+
in_array: &ArrayRef,
161+
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
162+
make_bitmap_filter::<U16Config>(in_array)
163+
}
164+
165+
// =============================================================================
166+
// TYPE DISPATCH
167+
// =============================================================================
168+
169+
/// Dispatch macro that routes primitive types by width to the appropriate UInt type.
170+
///
171+
/// All primitive types (Int*, UInt*, Float*, Timestamp*, Date*, Duration*, etc.) are
172+
/// automatically dispatched based on their width. The reinterpret function handles
173+
/// the fast path when source type already matches the destination UInt type.
174+
macro_rules! dispatch_primitive {
175+
($arr:expr, $reinterpret:ident) => {
176+
match $arr.data_type().primitive_width() {
177+
Some(1) => Some($reinterpret::<UInt8Type>($arr)),
178+
Some(2) => Some($reinterpret::<UInt16Type>($arr)),
179+
Some(4) => Some($reinterpret::<UInt32Type>($arr)),
180+
Some(8) => Some($reinterpret::<UInt64Type>($arr)),
181+
Some(16) => Some($reinterpret::<Decimal128Type>($arr)),
182+
_ => None,
50183
}
184+
};
185+
}
186+
187+
fn dispatch_branchless(
188+
arr: &ArrayRef,
189+
) -> Option<Result<Arc<dyn StaticFilter + Send + Sync>>> {
190+
fn make<D: ArrowPrimitiveType + 'static>(
191+
arr: &ArrayRef,
192+
) -> Result<Arc<dyn StaticFilter + Send + Sync>>
193+
where
194+
D::Native: Copy + PartialEq + Send + Sync + 'static,
195+
{
196+
make_branchless_filter::<D>(arr)
51197
}
198+
dispatch_primitive!(arr, make)
52199
}
53200

201+
fn dispatch_hashed(
202+
arr: &ArrayRef,
203+
) -> Option<Result<Arc<dyn StaticFilter + Send + Sync>>> {
204+
// Fast path: create PrimitiveFilter directly for common hashable types
205+
macro_rules! direct_filter {
206+
($T:ty) => {
207+
return Some(
208+
PrimitiveFilter::<$T>::try_new(arr)
209+
.map(|f| Arc::new(f) as Arc<dyn StaticFilter + Send + Sync>),
210+
)
211+
};
212+
}
213+
match arr.data_type() {
214+
DataType::Int32 => direct_filter!(Int32Type),
215+
DataType::Int64 => direct_filter!(Int64Type),
216+
DataType::UInt32 => direct_filter!(UInt32Type),
217+
DataType::UInt64 => direct_filter!(UInt64Type),
218+
_ => {}
219+
}
220+
221+
// For other types (Float32, Float64, Timestamp, etc.), reinterpret to UInt
222+
fn make<D: ArrowPrimitiveType + 'static>(
223+
arr: &ArrayRef,
224+
) -> Result<Arc<dyn StaticFilter + Send + Sync>>
225+
where
226+
D::Native: Hash + Eq + Send + Sync + 'static,
227+
{
228+
make_primitive_filter::<D>(arr)
229+
}
230+
dispatch_primitive!(arr, make)
231+
}

0 commit comments

Comments
 (0)