1818//! Implementation of `InList` expressions: [`InListExpr`]
1919
2020mod array_filter;
21+ mod strategy;
2122
2223use std:: any:: Any ;
2324use std:: fmt:: Debug ;
@@ -30,15 +31,15 @@ use crate::physical_expr::physical_exprs_bag_equal;
3031use arrow:: array:: * ;
3132use arrow:: buffer:: { BooleanBuffer , NullBuffer } ;
3233use arrow:: compute:: kernels:: boolean:: { not, or_kleene} ;
33- use arrow:: compute:: { take , SortOptions } ;
34+ use arrow:: compute:: SortOptions ;
3435use arrow:: datatypes:: * ;
3536use datafusion_common:: {
36- DFSchema , HashSet , Result , ScalarValue , assert_or_internal_err, exec_datafusion_err,
37- exec_err,
37+ DFSchema , Result , ScalarValue , assert_or_internal_err, exec_err,
3838} ;
3939use datafusion_expr:: { ColumnarValue , expr_vec_fmt} ;
4040
41- use array_filter:: { ArrayStaticFilter , StaticFilter } ;
41+ use array_filter:: StaticFilter ;
42+ use strategy:: instantiate_static_filter;
4243
4344/// InList
4445pub struct InListExpr {
@@ -58,347 +59,6 @@ impl Debug for InListExpr {
5859 }
5960}
6061
61- fn instantiate_static_filter (
62- in_array : ArrayRef ,
63- ) -> Result < Arc < dyn StaticFilter + Send + Sync > > {
64- match in_array. data_type ( ) {
65- // Integer primitive types
66- DataType :: Int8 => Ok ( Arc :: new ( Int8StaticFilter :: try_new ( & in_array) ?) ) ,
67- DataType :: Int16 => Ok ( Arc :: new ( Int16StaticFilter :: try_new ( & in_array) ?) ) ,
68- DataType :: Int32 => Ok ( Arc :: new ( Int32StaticFilter :: try_new ( & in_array) ?) ) ,
69- DataType :: Int64 => Ok ( Arc :: new ( Int64StaticFilter :: try_new ( & in_array) ?) ) ,
70- DataType :: UInt8 => Ok ( Arc :: new ( UInt8StaticFilter :: try_new ( & in_array) ?) ) ,
71- DataType :: UInt16 => Ok ( Arc :: new ( UInt16StaticFilter :: try_new ( & in_array) ?) ) ,
72- DataType :: UInt32 => Ok ( Arc :: new ( UInt32StaticFilter :: try_new ( & in_array) ?) ) ,
73- DataType :: UInt64 => Ok ( Arc :: new ( UInt64StaticFilter :: try_new ( & in_array) ?) ) ,
74- // Float primitive types (use ordered wrappers for Hash/Eq)
75- DataType :: Float32 => Ok ( Arc :: new ( Float32StaticFilter :: try_new ( & in_array) ?) ) ,
76- DataType :: Float64 => Ok ( Arc :: new ( Float64StaticFilter :: try_new ( & in_array) ?) ) ,
77- _ => {
78- /* fall through to generic implementation for unsupported types (Struct, etc.) */
79- Ok ( Arc :: new ( ArrayStaticFilter :: try_new ( in_array) ?) )
80- }
81- }
82- }
83-
84- /// Wrapper for f32 that implements Hash and Eq using bit comparison.
85- /// This treats NaN values as equal to each other when they have the same bit pattern.
86- #[ derive( Clone , Copy ) ]
87- struct OrderedFloat32 ( f32 ) ;
88-
89- impl Hash for OrderedFloat32 {
90- fn hash < H : Hasher > ( & self , state : & mut H ) {
91- self . 0 . to_ne_bytes ( ) . hash ( state) ;
92- }
93- }
94-
95- impl PartialEq for OrderedFloat32 {
96- fn eq ( & self , other : & Self ) -> bool {
97- self . 0 . to_bits ( ) == other. 0 . to_bits ( )
98- }
99- }
100-
101- impl Eq for OrderedFloat32 { }
102-
103- impl From < f32 > for OrderedFloat32 {
104- fn from ( v : f32 ) -> Self {
105- Self ( v)
106- }
107- }
108-
109- /// Wrapper for f64 that implements Hash and Eq using bit comparison.
110- /// This treats NaN values as equal to each other when they have the same bit pattern.
111- #[ derive( Clone , Copy ) ]
112- struct OrderedFloat64 ( f64 ) ;
113-
114- impl Hash for OrderedFloat64 {
115- fn hash < H : Hasher > ( & self , state : & mut H ) {
116- self . 0 . to_ne_bytes ( ) . hash ( state) ;
117- }
118- }
119-
120- impl PartialEq for OrderedFloat64 {
121- fn eq ( & self , other : & Self ) -> bool {
122- self . 0 . to_bits ( ) == other. 0 . to_bits ( )
123- }
124- }
125-
126- impl Eq for OrderedFloat64 { }
127-
128- impl From < f64 > for OrderedFloat64 {
129- fn from ( v : f64 ) -> Self {
130- Self ( v)
131- }
132- }
133-
134- // Macro to generate specialized StaticFilter implementations for primitive types
135- macro_rules! primitive_static_filter {
136- ( $Name: ident, $ArrowType: ty) => {
137- struct $Name {
138- null_count: usize ,
139- values: HashSet <<$ArrowType as ArrowPrimitiveType >:: Native >,
140- }
141-
142- impl $Name {
143- fn try_new( in_array: & ArrayRef ) -> Result <Self > {
144- let in_array = in_array
145- . as_primitive_opt:: <$ArrowType>( )
146- . ok_or_else( || exec_datafusion_err!( "Failed to downcast an array to a '{}' array" , stringify!( $ArrowType) ) ) ?;
147-
148- let mut values = HashSet :: with_capacity( in_array. len( ) ) ;
149- let null_count = in_array. null_count( ) ;
150-
151- for v in in_array. iter( ) . flatten( ) {
152- values. insert( v) ;
153- }
154-
155- Ok ( Self { null_count, values } )
156- }
157- }
158-
159- impl StaticFilter for $Name {
160- fn null_count( & self ) -> usize {
161- self . null_count
162- }
163-
164- fn contains( & self , v: & dyn Array , negated: bool ) -> Result <BooleanArray > {
165- // Handle dictionary arrays by recursing on the values
166- downcast_dictionary_array! {
167- v => {
168- let values_contains = self . contains( v. values( ) . as_ref( ) , negated) ?;
169- let result = take( & values_contains, v. keys( ) , None ) ?;
170- return Ok ( downcast_array( result. as_ref( ) ) )
171- }
172- _ => { }
173- }
174-
175- let v = v
176- . as_primitive_opt:: <$ArrowType>( )
177- . ok_or_else( || exec_datafusion_err!( "Failed to downcast an array to a '{}' array" , stringify!( $ArrowType) ) ) ?;
178-
179- let haystack_has_nulls = self . null_count > 0 ;
180-
181- let needle_values = v. values( ) ;
182- let needle_nulls = v. nulls( ) ;
183- let needle_has_nulls = v. null_count( ) > 0 ;
184-
185- // Truth table for `value [NOT] IN (set)` with SQL three-valued logic:
186- // ("-" means the value doesn't affect the result)
187- //
188- // | needle_null | haystack_null | negated | in set? | result |
189- // |-------------|---------------|---------|---------|--------|
190- // | true | - | false | - | null |
191- // | true | - | true | - | null |
192- // | false | true | false | yes | true |
193- // | false | true | false | no | null |
194- // | false | true | true | yes | false |
195- // | false | true | true | no | null |
196- // | false | false | false | yes | true |
197- // | false | false | false | no | false |
198- // | false | false | true | yes | false |
199- // | false | false | true | no | true |
200-
201- // Compute the "contains" result using collect_bool (fast batched approach)
202- // This ignores nulls - we handle them separately
203- let contains_buffer = if negated {
204- BooleanBuffer :: collect_bool( needle_values. len( ) , |i| {
205- !self . values. contains( & needle_values[ i] )
206- } )
207- } else {
208- BooleanBuffer :: collect_bool( needle_values. len( ) , |i| {
209- self . values. contains( & needle_values[ i] )
210- } )
211- } ;
212-
213- // Compute the null mask
214- // Output is null when:
215- // 1. needle value is null, OR
216- // 2. needle value is not in set AND haystack has nulls
217- let result_nulls = match ( needle_has_nulls, haystack_has_nulls) {
218- ( false , false ) => {
219- // No nulls anywhere
220- None
221- }
222- ( true , false ) => {
223- // Only needle has nulls - just use needle's null mask
224- needle_nulls. cloned( )
225- }
226- ( false , true ) => {
227- // Only haystack has nulls - result is null when value not in set
228- // Valid (not null) when original "in set" is true
229- // For NOT IN: contains_buffer = !original, so validity = !contains_buffer
230- let validity = if negated {
231- !& contains_buffer
232- } else {
233- contains_buffer. clone( )
234- } ;
235- Some ( NullBuffer :: new( validity) )
236- }
237- ( true , true ) => {
238- // Both have nulls - combine needle nulls with haystack-induced nulls
239- let needle_validity = needle_nulls. map( |n| n. inner( ) . clone( ) )
240- . unwrap_or_else( || BooleanBuffer :: new_set( needle_values. len( ) ) ) ;
241-
242- // Valid when original "in set" is true (see above)
243- let haystack_validity = if negated {
244- !& contains_buffer
245- } else {
246- contains_buffer. clone( )
247- } ;
248-
249- // Combined validity: valid only where both are valid
250- let combined_validity = & needle_validity & & haystack_validity;
251- Some ( NullBuffer :: new( combined_validity) )
252- }
253- } ;
254-
255- Ok ( BooleanArray :: new( contains_buffer, result_nulls) )
256- }
257- }
258- } ;
259- }
260-
261- // Generate specialized filters for all integer primitive types
262- primitive_static_filter ! ( Int8StaticFilter , Int8Type ) ;
263- primitive_static_filter ! ( Int16StaticFilter , Int16Type ) ;
264- primitive_static_filter ! ( Int32StaticFilter , Int32Type ) ;
265- primitive_static_filter ! ( Int64StaticFilter , Int64Type ) ;
266- primitive_static_filter ! ( UInt8StaticFilter , UInt8Type ) ;
267- primitive_static_filter ! ( UInt16StaticFilter , UInt16Type ) ;
268- primitive_static_filter ! ( UInt32StaticFilter , UInt32Type ) ;
269- primitive_static_filter ! ( UInt64StaticFilter , UInt64Type ) ;
270-
271- // Macro to generate specialized StaticFilter implementations for float types
272- // Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics
273- macro_rules! float_static_filter {
274- ( $Name: ident, $ArrowType: ty, $OrderedType: ty) => {
275- struct $Name {
276- null_count: usize ,
277- values: HashSet <$OrderedType>,
278- }
279-
280- impl $Name {
281- fn try_new( in_array: & ArrayRef ) -> Result <Self > {
282- let in_array = in_array
283- . as_primitive_opt:: <$ArrowType>( )
284- . ok_or_else( || exec_datafusion_err!( "Failed to downcast an array to a '{}' array" , stringify!( $ArrowType) ) ) ?;
285-
286- let mut values = HashSet :: with_capacity( in_array. len( ) ) ;
287- let null_count = in_array. null_count( ) ;
288-
289- for v in in_array. iter( ) . flatten( ) {
290- values. insert( <$OrderedType>:: from( v) ) ;
291- }
292-
293- Ok ( Self { null_count, values } )
294- }
295- }
296-
297- impl StaticFilter for $Name {
298- fn null_count( & self ) -> usize {
299- self . null_count
300- }
301-
302- fn contains( & self , v: & dyn Array , negated: bool ) -> Result <BooleanArray > {
303- // Handle dictionary arrays by recursing on the values
304- downcast_dictionary_array! {
305- v => {
306- let values_contains = self . contains( v. values( ) . as_ref( ) , negated) ?;
307- let result = take( & values_contains, v. keys( ) , None ) ?;
308- return Ok ( downcast_array( result. as_ref( ) ) )
309- }
310- _ => { }
311- }
312-
313- let v = v
314- . as_primitive_opt:: <$ArrowType>( )
315- . ok_or_else( || exec_datafusion_err!( "Failed to downcast an array to a '{}' array" , stringify!( $ArrowType) ) ) ?;
316-
317- let haystack_has_nulls = self . null_count > 0 ;
318-
319- let needle_values = v. values( ) ;
320- let needle_nulls = v. nulls( ) ;
321- let needle_has_nulls = v. null_count( ) > 0 ;
322-
323- // Truth table for `value [NOT] IN (set)` with SQL three-valued logic:
324- // ("-" means the value doesn't affect the result)
325- //
326- // | needle_null | haystack_null | negated | in set? | result |
327- // |-------------|---------------|---------|---------|--------|
328- // | true | - | false | - | null |
329- // | true | - | true | - | null |
330- // | false | true | false | yes | true |
331- // | false | true | false | no | null |
332- // | false | true | true | yes | false |
333- // | false | true | true | no | null |
334- // | false | false | false | yes | true |
335- // | false | false | false | no | false |
336- // | false | false | true | yes | false |
337- // | false | false | true | no | true |
338-
339- // Compute the "contains" result using collect_bool (fast batched approach)
340- // This ignores nulls - we handle them separately
341- let contains_buffer = if negated {
342- BooleanBuffer :: collect_bool( needle_values. len( ) , |i| {
343- !self . values. contains( & <$OrderedType>:: from( needle_values[ i] ) )
344- } )
345- } else {
346- BooleanBuffer :: collect_bool( needle_values. len( ) , |i| {
347- self . values. contains( & <$OrderedType>:: from( needle_values[ i] ) )
348- } )
349- } ;
350-
351- // Compute the null mask
352- // Output is null when:
353- // 1. needle value is null, OR
354- // 2. needle value is not in set AND haystack has nulls
355- let result_nulls = match ( needle_has_nulls, haystack_has_nulls) {
356- ( false , false ) => {
357- // No nulls anywhere
358- None
359- }
360- ( true , false ) => {
361- // Only needle has nulls - just use needle's null mask
362- needle_nulls. cloned( )
363- }
364- ( false , true ) => {
365- // Only haystack has nulls - result is null when value not in set
366- // Valid (not null) when original "in set" is true
367- // For NOT IN: contains_buffer = !original, so validity = !contains_buffer
368- let validity = if negated {
369- !& contains_buffer
370- } else {
371- contains_buffer. clone( )
372- } ;
373- Some ( NullBuffer :: new( validity) )
374- }
375- ( true , true ) => {
376- // Both have nulls - combine needle nulls with haystack-induced nulls
377- let needle_validity = needle_nulls. map( |n| n. inner( ) . clone( ) )
378- . unwrap_or_else( || BooleanBuffer :: new_set( needle_values. len( ) ) ) ;
379-
380- // Valid when original "in set" is true (see above)
381- let haystack_validity = if negated {
382- !& contains_buffer
383- } else {
384- contains_buffer. clone( )
385- } ;
386-
387- // Combined validity: valid only where both are valid
388- let combined_validity = & needle_validity & & haystack_validity;
389- Some ( NullBuffer :: new( combined_validity) )
390- }
391- } ;
392-
393- Ok ( BooleanArray :: new( contains_buffer, result_nulls) )
394- }
395- }
396- } ;
397- }
398-
399- // Generate specialized filters for float types using ordered wrappers
400- float_static_filter ! ( Float32StaticFilter , Float32Type , OrderedFloat32 ) ;
401- float_static_filter ! ( Float64StaticFilter , Float64Type , OrderedFloat64 ) ;
40262
40363/// Evaluates the list of expressions into an array, flattening any dictionaries
40464fn evaluate_list (
0 commit comments