@@ -27,7 +27,7 @@ use arrow::datatypes::*;
2727use datafusion_common:: { HashSet , Result , exec_datafusion_err} ;
2828
2929use super :: array_filter:: { ArrayStaticFilter , StaticFilter } ;
30- use super :: primitive:: { U8Config , U16Config } ;
30+ use super :: primitive:: { PrimitiveFilter , U8Config , U16Config } ;
3131use super :: transform:: make_bitmap_filter;
3232
3333pub ( crate ) fn instantiate_static_filter (
@@ -39,11 +39,11 @@ pub(crate) fn instantiate_static_filter(
3939 // 2-byte types: use bitmap (65536 bits = 8 KB)
4040 DataType :: Int16 | DataType :: UInt16 => make_bitmap_filter :: < U16Config > ( & in_array) ,
4141 // 4-byte integer types
42- DataType :: Int32 => Ok ( Arc :: new ( Int32StaticFilter :: try_new ( & in_array) ?) ) ,
43- DataType :: UInt32 => Ok ( Arc :: new ( UInt32StaticFilter :: try_new ( & in_array) ?) ) ,
42+ DataType :: Int32 => Ok ( Arc :: new ( PrimitiveFilter :: < Int32Type > :: try_new ( & in_array) ?) ) ,
43+ DataType :: UInt32 => Ok ( Arc :: new ( PrimitiveFilter :: < UInt32Type > :: try_new ( & in_array) ?) ) ,
4444 // 8-byte integer types
45- DataType :: Int64 => Ok ( Arc :: new ( Int64StaticFilter :: try_new ( & in_array) ?) ) ,
46- DataType :: UInt64 => Ok ( Arc :: new ( UInt64StaticFilter :: try_new ( & in_array) ?) ) ,
45+ DataType :: Int64 => Ok ( Arc :: new ( PrimitiveFilter :: < Int64Type > :: try_new ( & in_array) ?) ) ,
46+ DataType :: UInt64 => Ok ( Arc :: new ( PrimitiveFilter :: < UInt64Type > :: try_new ( & in_array) ?) ) ,
4747 // Float primitive types (use ordered wrappers for Hash/Eq)
4848 DataType :: Float32 => Ok ( Arc :: new ( Float32StaticFilter :: try_new ( & in_array) ?) ) ,
4949 DataType :: Float64 => Ok ( Arc :: new ( Float64StaticFilter :: try_new ( & in_array) ?) ) ,
@@ -104,139 +104,6 @@ impl From<f64> for OrderedFloat64 {
104104 }
105105}
106106
107- // Macro to generate specialized StaticFilter implementations for primitive types
108- macro_rules! primitive_static_filter {
109- ( $Name: ident, $ArrowType: ty) => {
110- struct $Name {
111- null_count: usize ,
112- values: HashSet <<$ArrowType as ArrowPrimitiveType >:: Native >,
113- }
114-
115- impl $Name {
116- fn try_new( in_array: & ArrayRef ) -> Result <Self > {
117- let in_array = in_array
118- . as_primitive_opt:: <$ArrowType>( )
119- . ok_or_else( || exec_datafusion_err!( "Failed to downcast an array to a '{}' array" , stringify!( $ArrowType) ) ) ?;
120-
121- let mut values = HashSet :: with_capacity( in_array. len( ) ) ;
122- let null_count = in_array. null_count( ) ;
123-
124- for v in in_array. iter( ) . flatten( ) {
125- values. insert( v) ;
126- }
127-
128- Ok ( Self { null_count, values } )
129- }
130- }
131-
132- impl StaticFilter for $Name {
133- fn null_count( & self ) -> usize {
134- self . null_count
135- }
136-
137- fn contains( & self , v: & dyn Array , negated: bool ) -> Result <BooleanArray > {
138- // Handle dictionary arrays by recursing on the values
139- downcast_dictionary_array! {
140- v => {
141- let values_contains = self . contains( v. values( ) . as_ref( ) , negated) ?;
142- let result = take( & values_contains, v. keys( ) , None ) ?;
143- return Ok ( downcast_array( result. as_ref( ) ) )
144- }
145- _ => { }
146- }
147-
148- let v = v
149- . as_primitive_opt:: <$ArrowType>( )
150- . ok_or_else( || exec_datafusion_err!( "Failed to downcast an array to a '{}' array" , stringify!( $ArrowType) ) ) ?;
151-
152- let haystack_has_nulls = self . null_count > 0 ;
153-
154- let needle_values = v. values( ) ;
155- let needle_nulls = v. nulls( ) ;
156- let needle_has_nulls = v. null_count( ) > 0 ;
157-
158- // Truth table for `value [NOT] IN (set)` with SQL three-valued logic:
159- // ("-" means the value doesn't affect the result)
160- //
161- // | needle_null | haystack_null | negated | in set? | result |
162- // |-------------|---------------|---------|---------|--------|
163- // | true | - | false | - | null |
164- // | true | - | true | - | null |
165- // | false | true | false | yes | true |
166- // | false | true | false | no | null |
167- // | false | true | true | yes | false |
168- // | false | true | true | no | null |
169- // | false | false | false | yes | true |
170- // | false | false | false | no | false |
171- // | false | false | true | yes | false |
172- // | false | false | true | no | true |
173-
174- // Compute the "contains" result using collect_bool (fast batched approach)
175- // This ignores nulls - we handle them separately
176- let contains_buffer = if negated {
177- BooleanBuffer :: collect_bool( needle_values. len( ) , |i| {
178- !self . values. contains( & needle_values[ i] )
179- } )
180- } else {
181- BooleanBuffer :: collect_bool( needle_values. len( ) , |i| {
182- self . values. contains( & needle_values[ i] )
183- } )
184- } ;
185-
186- // Compute the null mask
187- // Output is null when:
188- // 1. needle value is null, OR
189- // 2. needle value is not in set AND haystack has nulls
190- let result_nulls = match ( needle_has_nulls, haystack_has_nulls) {
191- ( false , false ) => {
192- // No nulls anywhere
193- None
194- }
195- ( true , false ) => {
196- // Only needle has nulls - just use needle's null mask
197- needle_nulls. cloned( )
198- }
199- ( false , true ) => {
200- // Only haystack has nulls - result is null when value not in set
201- // Valid (not null) when original "in set" is true
202- // For NOT IN: contains_buffer = !original, so validity = !contains_buffer
203- let validity = if negated {
204- !& contains_buffer
205- } else {
206- contains_buffer. clone( )
207- } ;
208- Some ( NullBuffer :: new( validity) )
209- }
210- ( true , true ) => {
211- // Both have nulls - combine needle nulls with haystack-induced nulls
212- let needle_validity = needle_nulls. map( |n| n. inner( ) . clone( ) )
213- . unwrap_or_else( || BooleanBuffer :: new_set( needle_values. len( ) ) ) ;
214-
215- // Valid when original "in set" is true (see above)
216- let haystack_validity = if negated {
217- !& contains_buffer
218- } else {
219- contains_buffer. clone( )
220- } ;
221-
222- // Combined validity: valid only where both are valid
223- let combined_validity = & needle_validity & & haystack_validity;
224- Some ( NullBuffer :: new( combined_validity) )
225- }
226- } ;
227-
228- Ok ( BooleanArray :: new( contains_buffer, result_nulls) )
229- }
230- }
231- } ;
232- }
233-
234- // Generate specialized filters for 4-byte and 8-byte integer primitive types
235- // (1-byte and 2-byte types use BitmapFilter instead)
236- primitive_static_filter ! ( Int32StaticFilter , Int32Type ) ;
237- primitive_static_filter ! ( Int64StaticFilter , Int64Type ) ;
238- primitive_static_filter ! ( UInt32StaticFilter , UInt32Type ) ;
239- primitive_static_filter ! ( UInt64StaticFilter , UInt64Type ) ;
240107
241108// Macro to generate specialized StaticFilter implementations for float types
242109// Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics
0 commit comments