Skip to content

Commit e6a8d57

Browse files
perf(in_list): add float support via type reinterpretation
Extends type reinterpretation to handle floats by treating their bit patterns as unsigned integers. For equality comparison, only the bit pattern matters, so Float64 can be reinterpreted as UInt64. Adds ReinterpretedPrimitive<D> filter that wraps PrimitiveFilter<D> and reinterprets input arrays at query time. The strategy module now routes Float32/Float64 through make_primitive_filter::<UInt32/UInt64>. This eliminates the need for OrderedFloat wrappers and their associated overhead, while maintaining correctness for NaN handling.
1 parent 88ad0f8 commit e6a8d57

1 file changed

Lines changed: 6 additions & 191 deletions

File tree

  • datafusion/physical-expr/src/expressions/in_list

datafusion/physical-expr/src/expressions/in_list/strategy.rs

Lines changed: 6 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,15 @@
1717

1818
//! Filter strategy selection for InList expressions
1919
20-
use std::hash::{Hash, Hasher};
2120
use std::sync::Arc;
2221

23-
use arrow::array::*;
24-
use arrow::buffer::{BooleanBuffer, NullBuffer};
25-
use arrow::compute::take;
22+
use arrow::array::ArrayRef;
2623
use arrow::datatypes::*;
27-
use datafusion_common::{HashSet, Result, exec_datafusion_err};
24+
use datafusion_common::Result;
2825

2926
use super::array_filter::{ArrayStaticFilter, StaticFilter};
3027
use super::primitive::{PrimitiveFilter, U8Config, U16Config};
31-
use super::transform::make_bitmap_filter;
28+
use super::transform::{make_bitmap_filter, make_primitive_filter};
3229

3330
pub(crate) fn instantiate_static_filter(
3431
in_array: ArrayRef,
@@ -44,195 +41,13 @@ pub(crate) fn instantiate_static_filter(
4441
// 8-byte integer types
4542
DataType::Int64 => Ok(Arc::new(PrimitiveFilter::<Int64Type>::try_new(&in_array)?)),
4643
DataType::UInt64 => Ok(Arc::new(PrimitiveFilter::<UInt64Type>::try_new(&in_array)?)),
47-
// Float primitive types (use ordered wrappers for Hash/Eq)
48-
DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)),
49-
DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)),
44+
// Float types: reinterpret as unsigned integers (same bit pattern = equal)
45+
DataType::Float32 => make_primitive_filter::<UInt32Type>(&in_array),
46+
DataType::Float64 => make_primitive_filter::<UInt64Type>(&in_array),
5047
_ => {
5148
/* fall through to generic implementation for unsupported types (Struct, etc.) */
5249
Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?))
5350
}
5451
}
5552
}
5653

57-
/// Wrapper for f32 that implements Hash and Eq using bit comparison.
58-
/// This treats NaN values as equal to each other when they have the same bit pattern.
59-
#[derive(Clone, Copy)]
60-
struct OrderedFloat32(f32);
61-
62-
impl Hash for OrderedFloat32 {
63-
fn hash<H: Hasher>(&self, state: &mut H) {
64-
self.0.to_ne_bytes().hash(state);
65-
}
66-
}
67-
68-
impl PartialEq for OrderedFloat32 {
69-
fn eq(&self, other: &Self) -> bool {
70-
self.0.to_bits() == other.0.to_bits()
71-
}
72-
}
73-
74-
impl Eq for OrderedFloat32 {}
75-
76-
impl From<f32> for OrderedFloat32 {
77-
fn from(v: f32) -> Self {
78-
Self(v)
79-
}
80-
}
81-
82-
/// Wrapper for f64 that implements Hash and Eq using bit comparison.
83-
/// This treats NaN values as equal to each other when they have the same bit pattern.
84-
#[derive(Clone, Copy)]
85-
struct OrderedFloat64(f64);
86-
87-
impl Hash for OrderedFloat64 {
88-
fn hash<H: Hasher>(&self, state: &mut H) {
89-
self.0.to_ne_bytes().hash(state);
90-
}
91-
}
92-
93-
impl PartialEq for OrderedFloat64 {
94-
fn eq(&self, other: &Self) -> bool {
95-
self.0.to_bits() == other.0.to_bits()
96-
}
97-
}
98-
99-
impl Eq for OrderedFloat64 {}
100-
101-
impl From<f64> for OrderedFloat64 {
102-
fn from(v: f64) -> Self {
103-
Self(v)
104-
}
105-
}
106-
107-
108-
// Macro to generate specialized StaticFilter implementations for float types
109-
// Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics
110-
macro_rules! float_static_filter {
111-
($Name:ident, $ArrowType:ty, $OrderedType:ty) => {
112-
struct $Name {
113-
null_count: usize,
114-
values: HashSet<$OrderedType>,
115-
}
116-
117-
impl $Name {
118-
fn try_new(in_array: &ArrayRef) -> Result<Self> {
119-
let in_array = in_array
120-
.as_primitive_opt::<$ArrowType>()
121-
.ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
122-
123-
let mut values = HashSet::with_capacity(in_array.len());
124-
let null_count = in_array.null_count();
125-
126-
for v in in_array.iter().flatten() {
127-
values.insert(<$OrderedType>::from(v));
128-
}
129-
130-
Ok(Self { null_count, values })
131-
}
132-
}
133-
134-
impl StaticFilter for $Name {
135-
fn null_count(&self) -> usize {
136-
self.null_count
137-
}
138-
139-
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
140-
// Handle dictionary arrays by recursing on the values
141-
downcast_dictionary_array! {
142-
v => {
143-
let values_contains = self.contains(v.values().as_ref(), negated)?;
144-
let result = take(&values_contains, v.keys(), None)?;
145-
return Ok(downcast_array(result.as_ref()))
146-
}
147-
_ => {}
148-
}
149-
150-
let v = v
151-
.as_primitive_opt::<$ArrowType>()
152-
.ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?;
153-
154-
let haystack_has_nulls = self.null_count > 0;
155-
156-
let needle_values = v.values();
157-
let needle_nulls = v.nulls();
158-
let needle_has_nulls = v.null_count() > 0;
159-
160-
// Truth table for `value [NOT] IN (set)` with SQL three-valued logic:
161-
// ("-" means the value doesn't affect the result)
162-
//
163-
// | needle_null | haystack_null | negated | in set? | result |
164-
// |-------------|---------------|---------|---------|--------|
165-
// | true | - | false | - | null |
166-
// | true | - | true | - | null |
167-
// | false | true | false | yes | true |
168-
// | false | true | false | no | null |
169-
// | false | true | true | yes | false |
170-
// | false | true | true | no | null |
171-
// | false | false | false | yes | true |
172-
// | false | false | false | no | false |
173-
// | false | false | true | yes | false |
174-
// | false | false | true | no | true |
175-
176-
// Compute the "contains" result using collect_bool (fast batched approach)
177-
// This ignores nulls - we handle them separately
178-
let contains_buffer = if negated {
179-
BooleanBuffer::collect_bool(needle_values.len(), |i| {
180-
!self.values.contains(&<$OrderedType>::from(needle_values[i]))
181-
})
182-
} else {
183-
BooleanBuffer::collect_bool(needle_values.len(), |i| {
184-
self.values.contains(&<$OrderedType>::from(needle_values[i]))
185-
})
186-
};
187-
188-
// Compute the null mask
189-
// Output is null when:
190-
// 1. needle value is null, OR
191-
// 2. needle value is not in set AND haystack has nulls
192-
let result_nulls = match (needle_has_nulls, haystack_has_nulls) {
193-
(false, false) => {
194-
// No nulls anywhere
195-
None
196-
}
197-
(true, false) => {
198-
// Only needle has nulls - just use needle's null mask
199-
needle_nulls.cloned()
200-
}
201-
(false, true) => {
202-
// Only haystack has nulls - result is null when value not in set
203-
// Valid (not null) when original "in set" is true
204-
// For NOT IN: contains_buffer = !original, so validity = !contains_buffer
205-
let validity = if negated {
206-
!&contains_buffer
207-
} else {
208-
contains_buffer.clone()
209-
};
210-
Some(NullBuffer::new(validity))
211-
}
212-
(true, true) => {
213-
// Both have nulls - combine needle nulls with haystack-induced nulls
214-
let needle_validity = needle_nulls.map(|n| n.inner().clone())
215-
.unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len()));
216-
217-
// Valid when original "in set" is true (see above)
218-
let haystack_validity = if negated {
219-
!&contains_buffer
220-
} else {
221-
contains_buffer.clone()
222-
};
223-
224-
// Combined validity: valid only where both are valid
225-
let combined_validity = &needle_validity & &haystack_validity;
226-
Some(NullBuffer::new(combined_validity))
227-
}
228-
};
229-
230-
Ok(BooleanArray::new(contains_buffer, result_nulls))
231-
}
232-
}
233-
};
234-
}
235-
236-
// Generate specialized filters for float types using ordered wrappers
237-
float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32);
238-
float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64);

0 commit comments

Comments
 (0)