Skip to content

Commit 4b34f4a

Browse files
Implement Legacy String Optimization (Utf8TwoStageFilter)
Port of the two-stage View optimization to standard Utf8 and LargeUtf8 types. Encodes strings as i128 (len + prefix) for fast O(1) pre-filtering before falling back to full string comparison. Triggers for Utf8 and LargeUtf8.
1 parent f7d6008 commit 4b34f4a

3 files changed

Lines changed: 310 additions & 4 deletions

File tree

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ use arrow::compute::SortOptions;
3030
use arrow::compute::kernels::boolean::{not, or_kleene};
3131
use arrow::compute::kernels::cmp::eq as arrow_eq;
3232
use arrow::datatypes::*;
33-
3433
use datafusion_common::{
3534
DFSchema, Result, ScalarValue, assert_or_internal_err, exec_err,
3635
};

datafusion/physical-expr/src/expressions/in_list/strategy.rs

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,19 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
//! Filter selection strategy for InList expressions
19+
//!
20+
//! Selects the optimal lookup strategy based on data type and list size:
21+
//!
22+
//! - 1-byte types (Int8/UInt8): bitmap (32 bytes, O(1) bit test)
23+
//! - 2-byte types (Int16/UInt16): bitmap (8 KB, O(1) bit test)
24+
//! - 4-byte types (Int32/Float32): branchless (≤32) or hash (>32)
25+
//! - 8-byte types (Int64/Float64): branchless (≤16) or hash (>16)
26+
//! - 16-byte types (Decimal128): branchless (≤4) or hash (>4)
27+
//! - Utf8View (short strings): branchless (≤4) or hash (>4)
28+
//! - Byte arrays (Utf8, Binary, etc.): ByteArrayFilter / ByteViewFilter
29+
//! - Other types: ArrayStaticFilter (fallback for List, Struct, Map, etc.)
30+
1831
use std::sync::Arc;
1932

2033
use arrow::array::ArrayRef;
@@ -27,13 +40,25 @@ use super::result::handle_dictionary;
2740
use super::static_filter::StaticFilter;
2841
use super::transform::{
2942
make_bitmap_filter, make_branchless_filter, make_byte_view_masked_filter,
30-
make_utf8view_branchless_filter, make_utf8view_hash_filter,
31-
reinterpret_any_primitive_to, utf8view_all_short_strings,
43+
make_utf8_two_stage_filter, make_utf8view_branchless_filter,
44+
make_utf8view_hash_filter, utf8_all_short_strings, utf8view_all_short_strings,
3245
};
3346

3447
// =============================================================================
3548
// LOOKUP STRATEGY THRESHOLDS (tuned via microbenchmarks)
3649
// =============================================================================
50+
//
51+
// Based on minimum batch time (8192 lookups per batch):
52+
// - Int8 (1 byte): BITMAP (32 bytes, always fastest)
53+
// - Int16 (2 bytes): BITMAP (8 KB, always fastest)
54+
// - Int32 (4 bytes): branchless up to 32, then hashset
55+
// - Int64 (8 bytes): branchless up to 16, then hashset
56+
// - Int128 (16 bytes): branchless up to 4, then hashset
57+
// - Byte arrays: ByteArrayFilter / ByteViewFilter
58+
// - Other types: ArrayStaticFilter (fallback for List, Struct, Map, etc.)
59+
//
60+
// NOTE: Binary search and linear scan were benchmarked but consistently
61+
// lost to the strategies above at all tested list sizes.
3762

3863
/// Maximum list size for branchless lookup on 4-byte primitives (Int32, UInt32, Float32).
3964
const BRANCHLESS_MAX_4B: usize = 32;
@@ -63,6 +88,10 @@ enum FilterStrategy {
6388
}
6489

6590
/// Determines the optimal lookup strategy based on data type and list size.
91+
///
92+
/// For 1-byte and 2-byte types, bitmap is always used (benchmarks show it's
93+
/// faster than both branchless and hashed at all list sizes).
94+
/// For larger types, cutoffs are tuned per byte-width.
6695
fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy {
6796
match dt.primitive_width() {
6897
Some(1) => FilterStrategy::Bitmap1B,
@@ -97,6 +126,9 @@ fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy {
97126
// =============================================================================
98127

99128
/// Creates the optimal static filter for the given array.
129+
///
130+
/// This is the main entry point for filter creation. It analyzes the array's
131+
/// data type and size to select the best lookup strategy.
100132
pub(super) fn instantiate_static_filter(
101133
in_array: ArrayRef,
102134
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
@@ -134,15 +166,31 @@ pub(super) fn instantiate_static_filter(
134166
exec_datafusion_err!("Hashed strategy selected but no filter for {:?}", dt)
135167
})?,
136168

169+
// Utf8/LargeUtf8: Two-stage filter when all IN-list strings are short (≤12 bytes).
170+
// Stage 1 encodes as i128 (length + first 12 bytes) for O(1) rejection.
171+
// When strings are long, the encoding can't definitively match and the
172+
// overhead regresses vs the generic fallback, so we skip it.
173+
(DataType::Utf8 | DataType::LargeUtf8, Generic)
174+
if utf8_all_short_strings(in_array.as_ref()) =>
175+
{
176+
make_utf8_two_stage_filter(in_array)
177+
}
178+
179+
// Binary variants: Use ArrayStaticFilter (make_comparator)
180+
(DataType::Binary | DataType::LargeBinary, Generic) => {
181+
Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?))
182+
}
183+
137184
// Byte view filters (Utf8View, BinaryView)
185+
// Both use two-stage filter: masked view pre-check + full verification
138186
(DataType::Utf8View, Generic) => {
139187
make_byte_view_masked_filter::<StringViewType>(in_array)
140188
}
141189
(DataType::BinaryView, Generic) => {
142190
make_byte_view_masked_filter::<BinaryViewType>(in_array)
143191
}
144192

145-
// Fallback for nested/complex types and strings.
193+
// Fallback for nested/complex types (List, Struct, Map, Union, etc.)
146194
(_, Generic) => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)),
147195
}
148196
}
@@ -155,6 +203,7 @@ fn dispatch_branchless(
155203
arr: &ArrayRef,
156204
) -> Option<Result<Arc<dyn StaticFilter + Send + Sync>>> {
157205
// Dispatch to width-specific branchless filter.
206+
// Each width has its own max size: 4B→32, 8B→16, 16B→4
158207
match arr.data_type().primitive_width() {
159208
Some(4) => Some(make_branchless_filter::<UInt32Type>(arr, 4)),
160209
Some(8) => Some(make_branchless_filter::<UInt64Type>(arr, 8)),
@@ -190,6 +239,8 @@ fn dispatch_hashed(
190239
Some(16) => Some(make_direct_probe_filter_reinterpreted::<Decimal128Type>(
191240
arr,
192241
)),
242+
// Other widths (1, 2) use Bitmap strategy and never reach here.
243+
// Unknown widths fall through to Generic strategy.
193244
_ => None,
194245
}
195246
}
@@ -202,6 +253,8 @@ where
202253
D: ArrowPrimitiveType + 'static,
203254
D::Native: Send + Sync + DirectProbeHashable + 'static,
204255
{
256+
use super::transform::reinterpret_any_primitive_to;
257+
205258
// Fast path: already the right type
206259
if in_array.data_type() == &D::DATA_TYPE {
207260
return Ok(Arc::new(DirectProbeFilter::<D>::try_new(in_array)?));

datafusion/physical-expr/src/expressions/in_list/transform.rs

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,3 +545,257 @@ where
545545
{
546546
Ok(Arc::new(ByteViewMaskedFilter::<T>::try_new(in_array)?))
547547
}
548+
549+
// =============================================================================
550+
// UTF8 TWO-STAGE FILTER (length+prefix pre-check + full verification)
551+
// =============================================================================
552+
//
553+
// Similar to ByteViewMaskedFilter but for regular Utf8/LargeUtf8 arrays.
554+
// Encodes strings as i128 with length + prefix for quick rejection.
555+
//
556+
// Encoding (Little Endian):
557+
// - Bytes 0-3: length (u32)
558+
// - Bytes 4-15: data (12 bytes)
559+
//
560+
// This naturally distinguishes short from long strings via the length field.
561+
// For short strings (≤12 bytes), the i128 contains all data → match is definitive.
562+
// For long strings (>12 bytes), a match requires full string comparison.
563+
564+
/// Encodes a string as i128 with length + prefix.
565+
/// Format: [len:u32][data:12 bytes] (Little Endian)
566+
#[inline(always)]
567+
fn encode_string_as_i128(s: &[u8]) -> i128 {
568+
let len = s.len();
569+
570+
// Optimization: Construct the i128 directly using arithmetic and pointer copy
571+
// to avoid Store-to-Load Forwarding (STLF) stalls on x64 and minimize LSU pressure on ARM.
572+
//
573+
// The layout in memory must match Utf8View: [4 bytes len][12 bytes data]
574+
let mut val: u128 = len as u128; // Length in bytes 0-3
575+
576+
// Safety: writing to the remaining bytes of an initialized u128.
577+
// We use a pointer copy for the string data as it is variable length (0-12 bytes).
578+
unsafe {
579+
let dst = (&mut val as *mut u128 as *mut u8).add(4);
580+
std::ptr::copy_nonoverlapping(s.as_ptr(), dst, len.min(INLINE_STRING_LEN));
581+
}
582+
583+
val as i128
584+
}
585+
586+
/// Two-stage filter for Utf8/LargeUtf8 arrays.
587+
///
588+
/// Stage 1: Quick rejection using length+prefix as i128
589+
/// - Non-matches rejected via O(1) DirectProbeFilter lookup
590+
/// - Short string matches (≤12 bytes) accepted immediately
591+
///
592+
/// Stage 2: Full verification for long string matches
593+
/// - Only reached when encoded i128 matches AND string length >12 bytes
594+
/// - Uses HashTable with full string comparison
595+
pub(crate) struct Utf8TwoStageFilter<O: arrow::array::OffsetSizeTrait> {
596+
/// The haystack array containing values to match against
597+
in_array: ArrayRef,
598+
/// DirectProbeFilter for O(1) encoded i128 quick rejection
599+
encoded_filter: DirectProbeFilter<Decimal128Type>,
600+
/// HashTable storing indices of long strings (>12 bytes) for Stage 2
601+
long_string_table: HashTable<usize>,
602+
/// Random state for consistent hashing
603+
state: RandomState,
604+
/// Whether all haystack strings are short (≤12 bytes) - enables fast path
605+
all_short: bool,
606+
_phantom: PhantomData<O>,
607+
}
608+
609+
impl<O: arrow::array::OffsetSizeTrait + 'static> Utf8TwoStageFilter<O> {
610+
pub(crate) fn try_new(in_array: ArrayRef) -> Result<Self> {
611+
use arrow::array::GenericStringArray;
612+
613+
let arr = in_array
614+
.as_any()
615+
.downcast_ref::<GenericStringArray<O>>()
616+
.expect("Utf8TwoStageFilter requires GenericStringArray");
617+
618+
let len = arr.len();
619+
let mut encoded_values = Vec::with_capacity(len);
620+
let state = RandomState::new();
621+
let mut long_string_table = HashTable::new();
622+
let mut all_short = true;
623+
624+
// Build encoded values and long string table
625+
for i in 0..len {
626+
if arr.is_null(i) {
627+
encoded_values.push(0);
628+
continue;
629+
}
630+
631+
let s = arr.value(i);
632+
let bytes = s.as_bytes();
633+
encoded_values.push(encode_string_as_i128(bytes));
634+
635+
if bytes.len() > INLINE_STRING_LEN {
636+
all_short = false;
637+
// Add to long string table for Stage 2 verification (with deduplication)
638+
let hash = state.hash_one(bytes);
639+
if long_string_table
640+
.find(hash, |&stored_idx| {
641+
arr.value(stored_idx).as_bytes() == bytes
642+
})
643+
.is_none()
644+
{
645+
long_string_table.insert_unique(hash, i, |&idx| {
646+
state.hash_one(arr.value(idx).as_bytes())
647+
});
648+
}
649+
}
650+
}
651+
652+
// Build DirectProbeFilter from encoded values
653+
let nulls = arr
654+
.nulls()
655+
.map(|n| arrow::buffer::NullBuffer::new(n.inner().clone()));
656+
let encoded_array: ArrayRef = Arc::new(PrimitiveArray::<Decimal128Type>::new(
657+
ScalarBuffer::from(encoded_values),
658+
nulls,
659+
));
660+
let encoded_filter =
661+
DirectProbeFilter::<Decimal128Type>::try_new(&encoded_array)?;
662+
663+
Ok(Self {
664+
in_array,
665+
encoded_filter,
666+
long_string_table,
667+
state,
668+
all_short,
669+
_phantom: PhantomData,
670+
})
671+
}
672+
}
673+
674+
impl<O: arrow::array::OffsetSizeTrait + 'static> StaticFilter for Utf8TwoStageFilter<O> {
675+
fn null_count(&self) -> usize {
676+
self.in_array.null_count()
677+
}
678+
679+
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
680+
use arrow::array::GenericStringArray;
681+
682+
handle_dictionary!(self, v, negated);
683+
684+
let needle_arr = v
685+
.as_any()
686+
.downcast_ref::<GenericStringArray<O>>()
687+
.expect("needle array type mismatch in Utf8TwoStageFilter");
688+
let haystack_arr = self
689+
.in_array
690+
.as_any()
691+
.downcast_ref::<GenericStringArray<O>>()
692+
.expect("haystack array type mismatch in Utf8TwoStageFilter");
693+
694+
let haystack_has_nulls = self.in_array.null_count() > 0;
695+
696+
if self.all_short {
697+
// Fast path: all haystack strings are short
698+
// Batch-encode all needles and do bulk lookup
699+
let needle_encoded: Vec<i128> = (0..needle_arr.len())
700+
.map(|i| {
701+
if needle_arr.is_null(i) {
702+
0
703+
} else {
704+
encode_string_as_i128(needle_arr.value(i).as_bytes())
705+
}
706+
})
707+
.collect();
708+
709+
// For short haystack, encoded match is definitive for short needles.
710+
// Long needles (>12 bytes) can never match, but their encoded form
711+
// won't match any short haystack encoding (different length field).
712+
return Ok(self.encoded_filter.contains_slice(
713+
&needle_encoded,
714+
needle_arr.nulls(),
715+
negated,
716+
));
717+
}
718+
719+
// Two-stage path: haystack has long strings
720+
Ok(super::result::build_in_list_result(
721+
v.len(),
722+
needle_arr.nulls(),
723+
haystack_has_nulls,
724+
negated,
725+
|i| {
726+
// SAFETY: i is in bounds [0, v.len()), guaranteed by build_in_list_result
727+
let needle_bytes = unsafe { needle_arr.value_unchecked(i) }.as_bytes();
728+
let encoded = encode_string_as_i128(needle_bytes);
729+
730+
// Stage 1: Quick rejection via encoded i128
731+
if !self.encoded_filter.contains_single(encoded) {
732+
return false;
733+
}
734+
735+
// Encoded match found
736+
let needle_len = needle_bytes.len();
737+
if needle_len <= INLINE_STRING_LEN {
738+
// Short needle: encoded contains all data, match is definitive
739+
// (If haystack had a long string with same prefix, its length
740+
// field would differ, so encoded wouldn't match)
741+
return true;
742+
}
743+
744+
// Stage 2: Long needle - verify with full string comparison
745+
let hash = self.state.hash_one(needle_bytes);
746+
self.long_string_table
747+
.find(hash, |&idx| {
748+
// SAFETY: idx was stored in try_new from valid indices into in_array
749+
unsafe { haystack_arr.value_unchecked(idx) }.as_bytes()
750+
== needle_bytes
751+
})
752+
.is_some()
753+
},
754+
))
755+
}
756+
}
757+
758+
/// Creates a two-stage filter for Utf8/LargeUtf8 arrays.
759+
/// Returns true if all non-null strings in a Utf8/LargeUtf8 array are ≤12 bytes.
760+
/// When false, the two-stage filter's Stage 1 cannot definitively match and the
761+
/// encoding overhead regresses performance vs the generic fallback.
762+
pub(crate) fn utf8_all_short_strings(array: &dyn Array) -> bool {
763+
use arrow::array::GenericStringArray;
764+
use arrow::datatypes::DataType;
765+
match array.data_type() {
766+
DataType::Utf8 => utf8_all_short_strings_impl::<i32>(
767+
array
768+
.as_any()
769+
.downcast_ref::<GenericStringArray<i32>>()
770+
.unwrap(),
771+
),
772+
DataType::LargeUtf8 => utf8_all_short_strings_impl::<i64>(
773+
array
774+
.as_any()
775+
.downcast_ref::<GenericStringArray<i64>>()
776+
.unwrap(),
777+
),
778+
_ => false,
779+
}
780+
}
781+
782+
fn utf8_all_short_strings_impl<O: arrow::array::OffsetSizeTrait>(
783+
arr: &arrow::array::GenericStringArray<O>,
784+
) -> bool {
785+
(0..arr.len()).all(|i| arr.is_null(i) || arr.value(i).len() <= INLINE_STRING_LEN)
786+
}
787+
788+
pub(crate) fn make_utf8_two_stage_filter(
789+
in_array: ArrayRef,
790+
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
791+
use arrow::datatypes::DataType;
792+
match in_array.data_type() {
793+
DataType::Utf8 => Ok(Arc::new(Utf8TwoStageFilter::<i32>::try_new(in_array)?)),
794+
DataType::LargeUtf8 => {
795+
Ok(Arc::new(Utf8TwoStageFilter::<i64>::try_new(in_array)?))
796+
}
797+
dt => datafusion_common::exec_err!(
798+
"Unsupported data type for Utf8 two-stage filter: {dt}"
799+
),
800+
}
801+
}

0 commit comments

Comments
 (0)