1717
1818//! Filter strategy selection for InList expressions
1919
20- use std:: hash:: { Hash , Hasher } ;
2120use std:: sync:: Arc ;
2221
23- use arrow:: array:: * ;
24- use arrow:: buffer:: { BooleanBuffer , NullBuffer } ;
25- use arrow:: compute:: take;
22+ use arrow:: array:: ArrayRef ;
2623use arrow:: datatypes:: * ;
27- use datafusion_common:: { HashSet , Result , exec_datafusion_err } ;
24+ use datafusion_common:: Result ;
2825
2926use super :: array_filter:: { ArrayStaticFilter , StaticFilter } ;
3027use super :: primitive:: { PrimitiveFilter , U8Config , U16Config } ;
31- use super :: transform:: make_bitmap_filter;
28+ use super :: transform:: { make_bitmap_filter, make_primitive_filter } ;
3229
3330pub ( crate ) fn instantiate_static_filter (
3431 in_array : ArrayRef ,
@@ -44,195 +41,13 @@ pub(crate) fn instantiate_static_filter(
4441 // 8-byte integer types
4542 DataType :: Int64 => Ok ( Arc :: new ( PrimitiveFilter :: < Int64Type > :: try_new ( & in_array) ?) ) ,
4643 DataType :: UInt64 => Ok ( Arc :: new ( PrimitiveFilter :: < UInt64Type > :: try_new ( & in_array) ?) ) ,
47- // Float primitive types (use ordered wrappers for Hash/Eq )
48- DataType :: Float32 => Ok ( Arc :: new ( Float32StaticFilter :: try_new ( & in_array) ? ) ) ,
49- DataType :: Float64 => Ok ( Arc :: new ( Float64StaticFilter :: try_new ( & in_array) ? ) ) ,
44+ // Float types: reinterpret as unsigned integers (same bit pattern = equal )
45+ DataType :: Float32 => make_primitive_filter :: < UInt32Type > ( & in_array) ,
46+ DataType :: Float64 => make_primitive_filter :: < UInt64Type > ( & in_array) ,
5047 _ => {
5148 /* fall through to generic implementation for unsupported types (Struct, etc.) */
5249 Ok ( Arc :: new ( ArrayStaticFilter :: try_new ( in_array) ?) )
5350 }
5451 }
5552}
5653
/// Wrapper for `f32` that implements `Hash` and `Eq` by comparing raw bit patterns.
///
/// Two values are equal exactly when `f32::to_bits` returns the same integer for both.
/// Consequences of that choice:
/// - NaN values with the same bit pattern compare equal to each other (unlike IEEE `==`).
/// - `0.0` and `-0.0` have different bit patterns and therefore compare as *distinct*.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat32(f32);

impl Hash for OrderedFloat32 {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Hash the same representation that `eq` compares, so the
        // `a == b => hash(a) == hash(b)` contract holds by construction.
        self.0.to_bits().hash(state);
    }
}

impl PartialEq for OrderedFloat32 {
    fn eq(&self, other: &Self) -> bool {
        self.0.to_bits() == other.0.to_bits()
    }
}

// Bitwise equality is reflexive (even for NaN), so `Eq` is sound here.
impl Eq for OrderedFloat32 {}

impl From<f32> for OrderedFloat32 {
    fn from(v: f32) -> Self {
        Self(v)
    }
}
81-
/// Wrapper for `f64` that implements `Hash` and `Eq` by comparing raw bit patterns.
///
/// Two values are equal exactly when `f64::to_bits` returns the same integer for both.
/// Consequences of that choice:
/// - NaN values with the same bit pattern compare equal to each other (unlike IEEE `==`).
/// - `0.0` and `-0.0` have different bit patterns and therefore compare as *distinct*.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat64(f64);

impl Hash for OrderedFloat64 {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Hash the same representation that `eq` compares, so the
        // `a == b => hash(a) == hash(b)` contract holds by construction.
        self.0.to_bits().hash(state);
    }
}

impl PartialEq for OrderedFloat64 {
    fn eq(&self, other: &Self) -> bool {
        self.0.to_bits() == other.0.to_bits()
    }
}

// Bitwise equality is reflexive (even for NaN), so `Eq` is sound here.
impl Eq for OrderedFloat64 {}

impl From<f64> for OrderedFloat64 {
    fn from(v: f64) -> Self {
        Self(v)
    }
}
106-
107-
108- // Macro to generate specialized StaticFilter implementations for float types
109- // Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics
110- macro_rules! float_static_filter {
111- ( $Name: ident, $ArrowType: ty, $OrderedType: ty) => {
112- struct $Name {
113- null_count: usize ,
114- values: HashSet <$OrderedType>,
115- }
116-
117- impl $Name {
118- fn try_new( in_array: & ArrayRef ) -> Result <Self > {
119- let in_array = in_array
120- . as_primitive_opt:: <$ArrowType>( )
121- . ok_or_else( || exec_datafusion_err!( "Failed to downcast an array to a '{}' array" , stringify!( $ArrowType) ) ) ?;
122-
123- let mut values = HashSet :: with_capacity( in_array. len( ) ) ;
124- let null_count = in_array. null_count( ) ;
125-
126- for v in in_array. iter( ) . flatten( ) {
127- values. insert( <$OrderedType>:: from( v) ) ;
128- }
129-
130- Ok ( Self { null_count, values } )
131- }
132- }
133-
134- impl StaticFilter for $Name {
135- fn null_count( & self ) -> usize {
136- self . null_count
137- }
138-
139- fn contains( & self , v: & dyn Array , negated: bool ) -> Result <BooleanArray > {
140- // Handle dictionary arrays by recursing on the values
141- downcast_dictionary_array! {
142- v => {
143- let values_contains = self . contains( v. values( ) . as_ref( ) , negated) ?;
144- let result = take( & values_contains, v. keys( ) , None ) ?;
145- return Ok ( downcast_array( result. as_ref( ) ) )
146- }
147- _ => { }
148- }
149-
150- let v = v
151- . as_primitive_opt:: <$ArrowType>( )
152- . ok_or_else( || exec_datafusion_err!( "Failed to downcast an array to a '{}' array" , stringify!( $ArrowType) ) ) ?;
153-
154- let haystack_has_nulls = self . null_count > 0 ;
155-
156- let needle_values = v. values( ) ;
157- let needle_nulls = v. nulls( ) ;
158- let needle_has_nulls = v. null_count( ) > 0 ;
159-
160- // Truth table for `value [NOT] IN (set)` with SQL three-valued logic:
161- // ("-" means the value doesn't affect the result)
162- //
163- // | needle_null | haystack_null | negated | in set? | result |
164- // |-------------|---------------|---------|---------|--------|
165- // | true | - | false | - | null |
166- // | true | - | true | - | null |
167- // | false | true | false | yes | true |
168- // | false | true | false | no | null |
169- // | false | true | true | yes | false |
170- // | false | true | true | no | null |
171- // | false | false | false | yes | true |
172- // | false | false | false | no | false |
173- // | false | false | true | yes | false |
174- // | false | false | true | no | true |
175-
176- // Compute the "contains" result using collect_bool (fast batched approach)
177- // This ignores nulls - we handle them separately
178- let contains_buffer = if negated {
179- BooleanBuffer :: collect_bool( needle_values. len( ) , |i| {
180- !self . values. contains( & <$OrderedType>:: from( needle_values[ i] ) )
181- } )
182- } else {
183- BooleanBuffer :: collect_bool( needle_values. len( ) , |i| {
184- self . values. contains( & <$OrderedType>:: from( needle_values[ i] ) )
185- } )
186- } ;
187-
188- // Compute the null mask
189- // Output is null when:
190- // 1. needle value is null, OR
191- // 2. needle value is not in set AND haystack has nulls
192- let result_nulls = match ( needle_has_nulls, haystack_has_nulls) {
193- ( false , false ) => {
194- // No nulls anywhere
195- None
196- }
197- ( true , false ) => {
198- // Only needle has nulls - just use needle's null mask
199- needle_nulls. cloned( )
200- }
201- ( false , true ) => {
202- // Only haystack has nulls - result is null when value not in set
203- // Valid (not null) when original "in set" is true
204- // For NOT IN: contains_buffer = !original, so validity = !contains_buffer
205- let validity = if negated {
206- !& contains_buffer
207- } else {
208- contains_buffer. clone( )
209- } ;
210- Some ( NullBuffer :: new( validity) )
211- }
212- ( true , true ) => {
213- // Both have nulls - combine needle nulls with haystack-induced nulls
214- let needle_validity = needle_nulls. map( |n| n. inner( ) . clone( ) )
215- . unwrap_or_else( || BooleanBuffer :: new_set( needle_values. len( ) ) ) ;
216-
217- // Valid when original "in set" is true (see above)
218- let haystack_validity = if negated {
219- !& contains_buffer
220- } else {
221- contains_buffer. clone( )
222- } ;
223-
224- // Combined validity: valid only where both are valid
225- let combined_validity = & needle_validity & & haystack_validity;
226- Some ( NullBuffer :: new( combined_validity) )
227- }
228- } ;
229-
230- Ok ( BooleanArray :: new( contains_buffer, result_nulls) )
231- }
232- }
233- } ;
234- }
235-
236- // Generate specialized filters for float types using ordered wrappers
237- float_static_filter ! ( Float32StaticFilter , Float32Type , OrderedFloat32 ) ;
238- float_static_filter ! ( Float64StaticFilter , Float64Type , OrderedFloat64 ) ;
0 commit comments