Skip to content

Commit 094df73

Browse files
refactor(in_list): extract filter dispatch to strategy module
Extract the filter instantiation logic and specialized filter implementations to a dedicated strategy module. This is a pure refactoring with no behavioral changes. Moves: - instantiate_static_filter function - OrderedFloat32/OrderedFloat64 wrappers - primitive_static_filter! macro (Int8..UInt64 filters) - float_static_filter! macro (Float32/Float64 filters)
1 parent e4df471 commit 094df73

3 files changed

Lines changed: 529 additions & 345 deletions

File tree

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 5 additions & 345 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
//! Implementation of `InList` expressions: [`InListExpr`]
1919
2020
mod array_filter;
21+
mod strategy;
2122

2223
use std::any::Any;
2324
use std::fmt::Debug;
@@ -30,15 +31,15 @@ use crate::physical_expr::physical_exprs_bag_equal;
3031
use arrow::array::*;
3132
use arrow::buffer::{BooleanBuffer, NullBuffer};
3233
use arrow::compute::kernels::boolean::{not, or_kleene};
33-
use arrow::compute::{take, SortOptions};
34+
use arrow::compute::SortOptions;
3435
use arrow::datatypes::*;
3536
use datafusion_common::{
36-
DFSchema, HashSet, Result, ScalarValue, assert_or_internal_err, exec_datafusion_err,
37-
exec_err,
37+
DFSchema, Result, ScalarValue, assert_or_internal_err, exec_err,
3838
};
3939
use datafusion_expr::{ColumnarValue, expr_vec_fmt};
4040

41-
use array_filter::{ArrayStaticFilter, StaticFilter};
41+
use array_filter::StaticFilter;
42+
use strategy::instantiate_static_filter;
4243

4344
/// InList
4445
pub struct InListExpr {
@@ -58,347 +59,6 @@ impl Debug for InListExpr {
5859
}
5960
}
6061

61-
fn instantiate_static_filter(
62-
in_array: ArrayRef,
63-
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
64-
match in_array.data_type() {
65-
// Integer primitive types
66-
DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)),
67-
DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)),
68-
DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)),
69-
DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)),
70-
DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)),
71-
DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)),
72-
DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)),
73-
DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)),
74-
// Float primitive types (use ordered wrappers for Hash/Eq)
75-
DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)),
76-
DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)),
77-
_ => {
78-
/* fall through to generic implementation for unsupported types (Struct, etc.) */
79-
Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?))
80-
}
81-
}
82-
}
83-
84-
/// Wrapper for f32 that implements Hash and Eq using bit comparison.
85-
/// This treats NaN values as equal to each other when they have the same bit pattern.
86-
#[derive(Clone, Copy)]
87-
struct OrderedFloat32(f32);
88-
89-
impl Hash for OrderedFloat32 {
90-
fn hash<H: Hasher>(&self, state: &mut H) {
91-
self.0.to_ne_bytes().hash(state);
92-
}
93-
}
94-
95-
impl PartialEq for OrderedFloat32 {
96-
fn eq(&self, other: &Self) -> bool {
97-
self.0.to_bits() == other.0.to_bits()
98-
}
99-
}
100-
101-
impl Eq for OrderedFloat32 {}
102-
103-
impl From<f32> for OrderedFloat32 {
104-
fn from(v: f32) -> Self {
105-
Self(v)
106-
}
107-
}
108-
109-
/// Wrapper for f64 that implements Hash and Eq using bit comparison.
110-
/// This treats NaN values as equal to each other when they have the same bit pattern.
111-
#[derive(Clone, Copy)]
112-
struct OrderedFloat64(f64);
113-
114-
impl Hash for OrderedFloat64 {
115-
fn hash<H: Hasher>(&self, state: &mut H) {
116-
self.0.to_ne_bytes().hash(state);
117-
}
118-
}
119-
120-
impl PartialEq for OrderedFloat64 {
121-
fn eq(&self, other: &Self) -> bool {
122-
self.0.to_bits() == other.0.to_bits()
123-
}
124-
}
125-
126-
impl Eq for OrderedFloat64 {}
127-
128-
impl From<f64> for OrderedFloat64 {
129-
fn from(v: f64) -> Self {
130-
Self(v)
131-
}
132-
}
133-
134-
// Macro to generate specialized StaticFilter implementations for primitive types
135-
macro_rules! primitive_static_filter {
136-
($Name:ident, $ArrowType:ty) => {
137-
struct $Name {
138-
null_count: usize,
139-
values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>,
140-
}
141-
142-
impl $Name {
143-
fn try_new(in_array: &ArrayRef) -> Result<Self> {
144-
let in_array = in_array
145-
.as_primitive_opt::<$ArrowType>()
146-
.ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
147-
148-
let mut values = HashSet::with_capacity(in_array.len());
149-
let null_count = in_array.null_count();
150-
151-
for v in in_array.iter().flatten() {
152-
values.insert(v);
153-
}
154-
155-
Ok(Self { null_count, values })
156-
}
157-
}
158-
159-
impl StaticFilter for $Name {
160-
fn null_count(&self) -> usize {
161-
self.null_count
162-
}
163-
164-
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
165-
// Handle dictionary arrays by recursing on the values
166-
downcast_dictionary_array! {
167-
v => {
168-
let values_contains = self.contains(v.values().as_ref(), negated)?;
169-
let result = take(&values_contains, v.keys(), None)?;
170-
return Ok(downcast_array(result.as_ref()))
171-
}
172-
_ => {}
173-
}
174-
175-
let v = v
176-
.as_primitive_opt::<$ArrowType>()
177-
.ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
178-
179-
let haystack_has_nulls = self.null_count > 0;
180-
181-
let needle_values = v.values();
182-
let needle_nulls = v.nulls();
183-
let needle_has_nulls = v.null_count() > 0;
184-
185-
// Truth table for `value [NOT] IN (set)` with SQL three-valued logic:
186-
// ("-" means the value doesn't affect the result)
187-
//
188-
// | needle_null | haystack_null | negated | in set? | result |
189-
// |-------------|---------------|---------|---------|--------|
190-
// | true | - | false | - | null |
191-
// | true | - | true | - | null |
192-
// | false | true | false | yes | true |
193-
// | false | true | false | no | null |
194-
// | false | true | true | yes | false |
195-
// | false | true | true | no | null |
196-
// | false | false | false | yes | true |
197-
// | false | false | false | no | false |
198-
// | false | false | true | yes | false |
199-
// | false | false | true | no | true |
200-
201-
// Compute the "contains" result using collect_bool (fast batched approach)
202-
// This ignores nulls - we handle them separately
203-
let contains_buffer = if negated {
204-
BooleanBuffer::collect_bool(needle_values.len(), |i| {
205-
!self.values.contains(&needle_values[i])
206-
})
207-
} else {
208-
BooleanBuffer::collect_bool(needle_values.len(), |i| {
209-
self.values.contains(&needle_values[i])
210-
})
211-
};
212-
213-
// Compute the null mask
214-
// Output is null when:
215-
// 1. needle value is null, OR
216-
// 2. needle value is not in set AND haystack has nulls
217-
let result_nulls = match (needle_has_nulls, haystack_has_nulls) {
218-
(false, false) => {
219-
// No nulls anywhere
220-
None
221-
}
222-
(true, false) => {
223-
// Only needle has nulls - just use needle's null mask
224-
needle_nulls.cloned()
225-
}
226-
(false, true) => {
227-
// Only haystack has nulls - result is null when value not in set
228-
// Valid (not null) when original "in set" is true
229-
// For NOT IN: contains_buffer = !original, so validity = !contains_buffer
230-
let validity = if negated {
231-
!&contains_buffer
232-
} else {
233-
contains_buffer.clone()
234-
};
235-
Some(NullBuffer::new(validity))
236-
}
237-
(true, true) => {
238-
// Both have nulls - combine needle nulls with haystack-induced nulls
239-
let needle_validity = needle_nulls.map(|n| n.inner().clone())
240-
.unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len()));
241-
242-
// Valid when original "in set" is true (see above)
243-
let haystack_validity = if negated {
244-
!&contains_buffer
245-
} else {
246-
contains_buffer.clone()
247-
};
248-
249-
// Combined validity: valid only where both are valid
250-
let combined_validity = &needle_validity & &haystack_validity;
251-
Some(NullBuffer::new(combined_validity))
252-
}
253-
};
254-
255-
Ok(BooleanArray::new(contains_buffer, result_nulls))
256-
}
257-
}
258-
};
259-
}
260-
261-
// Generate specialized filters for all integer primitive types
262-
primitive_static_filter!(Int8StaticFilter, Int8Type);
263-
primitive_static_filter!(Int16StaticFilter, Int16Type);
264-
primitive_static_filter!(Int32StaticFilter, Int32Type);
265-
primitive_static_filter!(Int64StaticFilter, Int64Type);
266-
primitive_static_filter!(UInt8StaticFilter, UInt8Type);
267-
primitive_static_filter!(UInt16StaticFilter, UInt16Type);
268-
primitive_static_filter!(UInt32StaticFilter, UInt32Type);
269-
primitive_static_filter!(UInt64StaticFilter, UInt64Type);
270-
271-
// Macro to generate specialized StaticFilter implementations for float types
272-
// Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics
273-
macro_rules! float_static_filter {
274-
($Name:ident, $ArrowType:ty, $OrderedType:ty) => {
275-
struct $Name {
276-
null_count: usize,
277-
values: HashSet<$OrderedType>,
278-
}
279-
280-
impl $Name {
281-
fn try_new(in_array: &ArrayRef) -> Result<Self> {
282-
let in_array = in_array
283-
.as_primitive_opt::<$ArrowType>()
284-
.ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
285-
286-
let mut values = HashSet::with_capacity(in_array.len());
287-
let null_count = in_array.null_count();
288-
289-
for v in in_array.iter().flatten() {
290-
values.insert(<$OrderedType>::from(v));
291-
}
292-
293-
Ok(Self { null_count, values })
294-
}
295-
}
296-
297-
impl StaticFilter for $Name {
298-
fn null_count(&self) -> usize {
299-
self.null_count
300-
}
301-
302-
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
303-
// Handle dictionary arrays by recursing on the values
304-
downcast_dictionary_array! {
305-
v => {
306-
let values_contains = self.contains(v.values().as_ref(), negated)?;
307-
let result = take(&values_contains, v.keys(), None)?;
308-
return Ok(downcast_array(result.as_ref()))
309-
}
310-
_ => {}
311-
}
312-
313-
let v = v
314-
.as_primitive_opt::<$ArrowType>()
315-
.ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
316-
317-
let haystack_has_nulls = self.null_count > 0;
318-
319-
let needle_values = v.values();
320-
let needle_nulls = v.nulls();
321-
let needle_has_nulls = v.null_count() > 0;
322-
323-
// Truth table for `value [NOT] IN (set)` with SQL three-valued logic:
324-
// ("-" means the value doesn't affect the result)
325-
//
326-
// | needle_null | haystack_null | negated | in set? | result |
327-
// |-------------|---------------|---------|---------|--------|
328-
// | true | - | false | - | null |
329-
// | true | - | true | - | null |
330-
// | false | true | false | yes | true |
331-
// | false | true | false | no | null |
332-
// | false | true | true | yes | false |
333-
// | false | true | true | no | null |
334-
// | false | false | false | yes | true |
335-
// | false | false | false | no | false |
336-
// | false | false | true | yes | false |
337-
// | false | false | true | no | true |
338-
339-
// Compute the "contains" result using collect_bool (fast batched approach)
340-
// This ignores nulls - we handle them separately
341-
let contains_buffer = if negated {
342-
BooleanBuffer::collect_bool(needle_values.len(), |i| {
343-
!self.values.contains(&<$OrderedType>::from(needle_values[i]))
344-
})
345-
} else {
346-
BooleanBuffer::collect_bool(needle_values.len(), |i| {
347-
self.values.contains(&<$OrderedType>::from(needle_values[i]))
348-
})
349-
};
350-
351-
// Compute the null mask
352-
// Output is null when:
353-
// 1. needle value is null, OR
354-
// 2. needle value is not in set AND haystack has nulls
355-
let result_nulls = match (needle_has_nulls, haystack_has_nulls) {
356-
(false, false) => {
357-
// No nulls anywhere
358-
None
359-
}
360-
(true, false) => {
361-
// Only needle has nulls - just use needle's null mask
362-
needle_nulls.cloned()
363-
}
364-
(false, true) => {
365-
// Only haystack has nulls - result is null when value not in set
366-
// Valid (not null) when original "in set" is true
367-
// For NOT IN: contains_buffer = !original, so validity = !contains_buffer
368-
let validity = if negated {
369-
!&contains_buffer
370-
} else {
371-
contains_buffer.clone()
372-
};
373-
Some(NullBuffer::new(validity))
374-
}
375-
(true, true) => {
376-
// Both have nulls - combine needle nulls with haystack-induced nulls
377-
let needle_validity = needle_nulls.map(|n| n.inner().clone())
378-
.unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len()));
379-
380-
// Valid when original "in set" is true (see above)
381-
let haystack_validity = if negated {
382-
!&contains_buffer
383-
} else {
384-
contains_buffer.clone()
385-
};
386-
387-
// Combined validity: valid only where both are valid
388-
let combined_validity = &needle_validity & &haystack_validity;
389-
Some(NullBuffer::new(combined_validity))
390-
}
391-
};
392-
393-
Ok(BooleanArray::new(contains_buffer, result_nulls))
394-
}
395-
}
396-
};
397-
}
398-
399-
// Generate specialized filters for float types using ordered wrappers
400-
float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32);
401-
float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64);
40262

40363
/// Evaluates the list of expressions into an array, flattening any dictionaries
40464
fn evaluate_list(

0 commit comments

Comments
 (0)