@@ -175,30 +175,48 @@ impl<C: BitmapFilterConfig> StaticFilter for BitmapFilter<C> {
175175}
176176
177177// =============================================================================
178- // PRIMITIVE FILTER (Hash-based )
178+ // BRANCHLESS FILTER (Const Generic for Small Lists )
179179// =============================================================================
180180
181- /// Hash-based filter for primitive types with larger IN lists.
182- pub ( crate ) struct PrimitiveFilter < T : ArrowPrimitiveType > {
181+ /// A branchless filter for very small IN lists (0-16 elements).
182+ ///
183+ /// Uses const generics to unroll the membership check into a fixed-size
184+ /// comparison chain, outperforming hash lookups for small lists due to:
185+ /// - No branching (uses bitwise OR to combine comparisons)
186+ /// - Better CPU pipelining
187+ /// - No hash computation overhead
188+ pub ( crate ) struct BranchlessFilter < T : ArrowPrimitiveType , const N : usize > {
183189 null_count : usize ,
184- set : HashSet < T :: Native > ,
190+ values : [ T :: Native ; N ] ,
185191}
186192
187- impl < T : ArrowPrimitiveType > PrimitiveFilter < T >
193+ impl < T : ArrowPrimitiveType , const N : usize > BranchlessFilter < T , N >
188194where
189- T :: Native : Hash + Eq ,
195+ T :: Native : Copy + PartialEq ,
190196{
191- pub ( crate ) fn try_new ( in_array : & ArrayRef ) -> Result < Self > {
192- let arr = in_array. as_primitive_opt :: < T > ( ) . ok_or_else ( || {
193- exec_datafusion_err ! (
194- "PrimitiveFilter: expected {} array" ,
195- std:: any:: type_name:: <T >( )
196- )
197- } ) ?;
198- Ok ( Self {
199- null_count : arr. null_count ( ) ,
200- set : arr. iter ( ) . flatten ( ) . collect ( ) ,
201- } )
197+ /// Try to create a branchless filter if the array has exactly N non-null values.
198+ pub ( crate ) fn try_new ( in_array : & ArrayRef ) -> Option < Result < Self > > {
199+ let in_array = in_array. as_primitive_opt :: < T > ( ) ?;
200+ let non_null_count = in_array. len ( ) - in_array. null_count ( ) ;
201+ if non_null_count != N {
202+ return None ;
203+ }
204+ let values: Vec < _ > = in_array. iter ( ) . flatten ( ) . collect ( ) ;
205+ // Use default_value() from ArrowPrimitiveType trait instead of Default::default()
206+ let mut arr = [ T :: default_value ( ) ; N ] ;
207+ arr. copy_from_slice ( & values) ;
208+ Some ( Ok ( Self {
209+ null_count : in_array. null_count ( ) ,
210+ values : arr,
211+ } ) )
212+ }
213+
214+ /// Branchless membership check using OR-chain.
215+ #[ inline( always) ]
216+ fn check ( & self , needle : T :: Native ) -> bool {
217+ self . values
218+ . iter ( )
219+ . fold ( false , |acc, & v| acc | ( v == needle) )
202220 }
203221
204222 /// Check membership using a raw values slice (zero-copy path for type reinterpretation).
@@ -214,16 +232,14 @@ where
214232 nulls,
215233 self . null_count > 0 ,
216234 negated,
217- // SAFETY: i is in bounds since we iterate 0..values.len()
218- |i| self . set . contains ( unsafe { values. get_unchecked ( i) } ) ,
235+ |i| self . check ( unsafe { * values. get_unchecked ( i) } ) ,
219236 )
220237 }
221238}
222239
223- impl < T > StaticFilter for PrimitiveFilter < T >
240+ impl < T : ArrowPrimitiveType , const N : usize > StaticFilter for BranchlessFilter < T , N >
224241where
225- T : ArrowPrimitiveType + ' static ,
226- T :: Native : Hash + Eq + Send + Sync + ' static ,
242+ T :: Native : Copy + PartialEq + Send + Sync ,
227243{
228244 fn null_count ( & self ) -> usize {
229245 self . null_count
@@ -232,66 +248,46 @@ where
232248 fn contains ( & self , v : & dyn Array , negated : bool ) -> Result < BooleanArray > {
233249 handle_dictionary ! ( self , v, negated) ;
234250 let v = v. as_primitive_opt :: < T > ( ) . ok_or_else ( || {
235- exec_datafusion_err ! (
236- "PrimitiveFilter: expected {} array" ,
237- std:: any:: type_name:: <T >( )
238- )
251+ exec_datafusion_err ! ( "Failed to downcast array to primitive type" )
239252 } ) ?;
240- let values = v. values ( ) ;
253+ let input_values = v. values ( ) ;
241254 Ok ( build_in_list_result (
242255 v. len ( ) ,
243256 v. nulls ( ) ,
244257 self . null_count > 0 ,
245258 negated,
246259 // SAFETY: i is in bounds since we iterate 0..v.len()
247- |i| self . set . contains ( unsafe { values. get_unchecked ( i) } ) ,
260+ #[ inline( always) ]
261+ |i| self . check ( unsafe { * input_values. get_unchecked ( i) } ) ,
248262 ) )
249263 }
250264}
251265
252266// =============================================================================
253- // BRANCHLESS FILTER (Const Generic for Small Lists )
267+ // PRIMITIVE FILTER (Hash-based )
254268// =============================================================================
255269
256- /// A branchless filter for very small IN lists (0-16 elements).
257- ///
258- /// Uses const generics to unroll the membership check into a fixed-size
259- /// comparison chain, outperforming hash lookups for small lists due to:
260- /// - No branching (uses bitwise OR to combine comparisons)
261- /// - Better CPU pipelining
262- /// - No hash computation overhead
263- pub ( crate ) struct BranchlessFilter < T : ArrowPrimitiveType , const N : usize > {
270+ /// Hash-based filter for primitive types with larger IN lists.
271+ pub ( crate ) struct PrimitiveFilter < T : ArrowPrimitiveType > {
264272 null_count : usize ,
265- values : [ T :: Native ; N ] ,
273+ set : HashSet < T :: Native > ,
266274}
267275
268- impl < T : ArrowPrimitiveType , const N : usize > BranchlessFilter < T , N >
276+ impl < T : ArrowPrimitiveType > PrimitiveFilter < T >
269277where
270- T :: Native : Copy + PartialEq ,
278+ T :: Native : Hash + Eq ,
271279{
272- /// Try to create a branchless filter if the array has exactly N non-null values.
273- pub ( crate ) fn try_new ( in_array : & ArrayRef ) -> Option < Result < Self > > {
274- let in_array = in_array. as_primitive_opt :: < T > ( ) ?;
275- let non_null_count = in_array. len ( ) - in_array. null_count ( ) ;
276- if non_null_count != N {
277- return None ;
278- }
279- let values: Vec < _ > = in_array. iter ( ) . flatten ( ) . collect ( ) ;
280- // Use default_value() from ArrowPrimitiveType trait instead of Default::default()
281- let mut arr = [ T :: default_value ( ) ; N ] ;
282- arr. copy_from_slice ( & values) ;
283- Some ( Ok ( Self {
284- null_count : in_array. null_count ( ) ,
285- values : arr,
286- } ) )
287- }
288-
289- /// Branchless membership check using OR-chain.
290- #[ inline( always) ]
291- fn check ( & self , needle : T :: Native ) -> bool {
292- self . values
293- . iter ( )
294- . fold ( false , |acc, & v| acc | ( v == needle) )
280+ pub ( crate ) fn try_new ( in_array : & ArrayRef ) -> Result < Self > {
281+ let arr = in_array. as_primitive_opt :: < T > ( ) . ok_or_else ( || {
282+ exec_datafusion_err ! (
283+ "PrimitiveFilter: expected {} array" ,
284+ std:: any:: type_name:: <T >( )
285+ )
286+ } ) ?;
287+ Ok ( Self {
288+ null_count : arr. null_count ( ) ,
289+ set : arr. iter ( ) . flatten ( ) . collect ( ) ,
290+ } )
295291 }
296292
297293 /// Check membership using a raw values slice (zero-copy path for type reinterpretation).
@@ -307,14 +303,16 @@ where
307303 nulls,
308304 self . null_count > 0 ,
309305 negated,
310- |i| self . check ( unsafe { * values. get_unchecked ( i) } ) ,
306+ // SAFETY: i is in bounds since we iterate 0..values.len()
307+ |i| self . set . contains ( unsafe { values. get_unchecked ( i) } ) ,
311308 )
312309 }
313310}
314311
315- impl < T : ArrowPrimitiveType , const N : usize > StaticFilter for BranchlessFilter < T , N >
312+ impl < T > StaticFilter for PrimitiveFilter < T >
316313where
317- T :: Native : Copy + PartialEq + Send + Sync ,
314+ T : ArrowPrimitiveType + ' static ,
315+ T :: Native : Hash + Eq + Send + Sync + ' static ,
318316{
319317 fn null_count ( & self ) -> usize {
320318 self . null_count
@@ -323,17 +321,19 @@ where
323321 fn contains ( & self , v : & dyn Array , negated : bool ) -> Result < BooleanArray > {
324322 handle_dictionary ! ( self , v, negated) ;
325323 let v = v. as_primitive_opt :: < T > ( ) . ok_or_else ( || {
326- exec_datafusion_err ! ( "Failed to downcast array to primitive type" )
324+ exec_datafusion_err ! (
325+ "PrimitiveFilter: expected {} array" ,
326+ std:: any:: type_name:: <T >( )
327+ )
327328 } ) ?;
328- let input_values = v. values ( ) ;
329+ let values = v. values ( ) ;
329330 Ok ( build_in_list_result (
330331 v. len ( ) ,
331332 v. nulls ( ) ,
332333 self . null_count > 0 ,
333334 negated,
334335 // SAFETY: i is in bounds since we iterate 0..v.len()
335- #[ inline( always) ]
336- |i| self . check ( unsafe { * input_values. get_unchecked ( i) } ) ,
336+ |i| self . set . contains ( unsafe { values. get_unchecked ( i) } ) ,
337337 ) )
338338 }
339339}
0 commit comments