Skip to content

Commit 4c195b4

Browse files
authored
perf: Optimize strpos() for scalar needle, plus optimize UTF-8 codepath (#20754)
## Which issue does this PR close?

- Closes #20753.

## Rationale for this change

This PR implements two mostly unrelated optimizations for `strpos`:

1. When the needle is scalar, we can build a single `memmem::Finder` and use it to search each row of the haystack. It turns out that this is significantly faster than using `memchr`, and the cost of constructing the finder is negligible because it is amortized over the batch.
2. We previously optimized strpos to use memchr for searching when both haystack and needle are ASCII-only (#20295). That was needlessly conservative: UTF-8 is self-synchronizing, so it should be safe to use `memchr` to search for matches for any combination of ASCII and UTF-8 needle and haystack.

The performance improvement depends on a bunch of factors (ASCII vs. UTF-8, scalar vs array needle, length of haystack strings), but ranges from 5% for short ASCII strings with a scalar needle to 15x for long UTF-8 strings with a scalar needle.

## What changes are included in this PR?

* Improve SLT test coverage for `strpos`
* Refactor and extend `strpos` benchmarks to cover the scalar case
* Implement optimizations described above
* Code cleanup and refactoring for the `strpos` implementation

## Are these changes tested?

Yes; new test cases and benchmarks added.

## Are there any user-facing changes?

No.
1 parent 1cb4de4 commit 4c195b4

3 files changed

Lines changed: 460 additions & 249 deletions

File tree

datafusion/functions/benches/strpos.rs

Lines changed: 181 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -18,178 +18,216 @@
1818
use arrow::array::{StringArray, StringViewArray};
1919
use arrow::datatypes::{DataType, Field};
2020
use criterion::{Criterion, criterion_group, criterion_main};
21+
use datafusion_common::ScalarValue;
2122
use datafusion_common::config::ConfigOptions;
2223
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
2324
use rand::distr::Alphanumeric;
2425
use rand::prelude::StdRng;
2526
use rand::{Rng, SeedableRng};
2627
use std::hint::black_box;
27-
use std::str::Chars;
2828
use std::sync::Arc;
2929

30-
/// Returns a `Vec<ColumnarValue>` with two elements: a haystack array and a
31-
/// needle array. Each haystack is a random string of `str_len_chars`
32-
/// characters. Each needle is a random contiguous substring of its
33-
/// corresponding haystack (i.e., the needle is always present in the haystack).
34-
/// Around `null_density` fraction of rows are null and `utf8_density` fraction
35-
/// contain non-ASCII characters; the remaining rows are ASCII-only.
36-
fn gen_string_array(
37-
n_rows: usize,
30+
#[rustfmt::skip]
31+
const UTF8_CORPUS: &[char] = &[
32+
// Cyrillic (2 bytes each)
33+
'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ж', 'З', 'И', 'К', 'Л', 'М', 'Н', 'О', 'П', 'Р', 'С',
34+
'Т', 'У', 'Ф', 'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Э', 'Ю', 'Я',
35+
// CJK (3 bytes each)
36+
'数', '据', '融', '合', '查', '询', '引', '擎', '优', '化', '执', '行', '计', '划',
37+
'表', '达',
38+
// Emoji (4 bytes each, except '⚡' (U+26A1), which is 3 bytes)
39+
'📊', '🔥', '🚀', '⚡', '🎯', '💡', '🔧', '📈',
40+
];
41+
const N_ROWS: usize = 8192;
42+
43+
/// Returns a random string of `len` characters. If `ascii` is true, the string
44+
/// is ASCII-only; otherwise it is drawn from `UTF8_CORPUS`.
45+
fn random_string(rng: &mut StdRng, len: usize, ascii: bool) -> String {
46+
if ascii {
47+
let value: Vec<u8> = rng.sample_iter(&Alphanumeric).take(len).collect();
48+
String::from_utf8(value).unwrap()
49+
} else {
50+
(0..len)
51+
.map(|_| UTF8_CORPUS[rng.random_range(0..UTF8_CORPUS.len())])
52+
.collect()
53+
}
54+
}
55+
56+
/// Wraps `strings` into either a `StringArray` or `StringViewArray`.
57+
fn to_columnar_value(
58+
strings: Vec<Option<String>>,
59+
is_string_view: bool,
60+
) -> ColumnarValue {
61+
if is_string_view {
62+
let arr: StringViewArray = strings.into_iter().collect();
63+
ColumnarValue::Array(Arc::new(arr))
64+
} else {
65+
let arr: StringArray = strings.into_iter().collect();
66+
ColumnarValue::Array(Arc::new(arr))
67+
}
68+
}
69+
70+
/// Returns haystack and needle, where both are arrays. Each needle is a
71+
/// contiguous substring of its corresponding haystack. Around `null_density`
72+
/// fraction of rows are null and `utf8_density` fraction contain non-ASCII
73+
/// characters.
74+
fn make_array_needle_args(
75+
rng: &mut StdRng,
3876
str_len_chars: usize,
3977
null_density: f32,
4078
utf8_density: f32,
41-
is_string_view: bool, // false -> StringArray, true -> StringViewArray
79+
is_string_view: bool,
4280
) -> Vec<ColumnarValue> {
43-
let mut rng = StdRng::seed_from_u64(42);
44-
let rng_ref = &mut rng;
45-
46-
let utf8 = "DatafusionДатаФусион数据融合📊🔥"; // includes utf8 encoding with 1~4 bytes
47-
let corpus_char_count = utf8.chars().count();
48-
49-
let mut output_string_vec: Vec<Option<String>> = Vec::with_capacity(n_rows);
50-
let mut output_sub_string_vec: Vec<Option<String>> = Vec::with_capacity(n_rows);
51-
for _ in 0..n_rows {
52-
let rand_num = rng_ref.random::<f32>(); // [0.0, 1.0)
53-
if rand_num < null_density {
54-
output_sub_string_vec.push(None);
55-
output_string_vec.push(None);
56-
} else if rand_num < null_density + utf8_density {
57-
// Generate random UTF8 string
58-
let mut generated_string = String::with_capacity(str_len_chars);
59-
for _ in 0..str_len_chars {
60-
let idx = rng_ref.random_range(0..corpus_char_count);
61-
let char = utf8.chars().nth(idx).unwrap();
62-
generated_string.push(char);
63-
}
64-
output_sub_string_vec.push(Some(random_substring(generated_string.chars())));
65-
output_string_vec.push(Some(generated_string));
81+
let mut haystacks: Vec<Option<String>> = Vec::with_capacity(N_ROWS);
82+
let mut needles: Vec<Option<String>> = Vec::with_capacity(N_ROWS);
83+
for _ in 0..N_ROWS {
84+
let r = rng.random::<f32>();
85+
if r < null_density {
86+
haystacks.push(None);
87+
needles.push(None);
6688
} else {
67-
// Generate random ASCII-only string
68-
let value = rng_ref
89+
let ascii = r >= null_density + utf8_density;
90+
let s = random_string(rng, str_len_chars, ascii);
91+
needles.push(Some(random_substring(rng, &s)));
92+
haystacks.push(Some(s));
93+
}
94+
}
95+
96+
vec![
97+
to_columnar_value(haystacks, is_string_view),
98+
to_columnar_value(needles, is_string_view),
99+
]
100+
}
101+
102+
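/// Returns a haystack array plus a fixed scalar needle that occurs in every
103+
/// non-null row: ASCII rows have it written over existing characters, while
104+
/// non-ASCII rows have it inserted (making them longer). Around `null_density` fraction of rows are null and `utf8_density` fraction contain non-ASCII characters. The needle must be ASCII.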
105+
fn make_scalar_needle_args(
106+
rng: &mut StdRng,
107+
str_len_chars: usize,
108+
needle: &str,
109+
null_density: f32,
110+
utf8_density: f32,
111+
is_string_view: bool,
112+
) -> Vec<ColumnarValue> {
113+
let needle_len = needle.len();
114+
assert!(
115+
str_len_chars >= needle_len,
116+
"str_len_chars must be >= needle length"
117+
);
118+
119+
let mut haystacks: Vec<Option<String>> = Vec::with_capacity(N_ROWS);
120+
for _ in 0..N_ROWS {
121+
let r = rng.random::<f32>();
122+
if r < null_density {
123+
haystacks.push(None);
124+
} else if r >= null_density + utf8_density {
125+
let mut value: Vec<u8> = (&mut *rng)
69126
.sample_iter(&Alphanumeric)
70127
.take(str_len_chars)
71128
.collect();
72-
let value = String::from_utf8(value).unwrap();
73-
output_sub_string_vec.push(Some(random_substring(value.chars())));
74-
output_string_vec.push(Some(value));
129+
let pos = rng.random_range(0..=str_len_chars - needle_len);
130+
value[pos..pos + needle_len].copy_from_slice(needle.as_bytes());
131+
haystacks.push(Some(String::from_utf8(value).unwrap()));
132+
} else {
133+
let mut s = random_string(rng, str_len_chars, false);
134+
let char_positions: Vec<usize> = s.char_indices().map(|(i, _)| i).collect();
135+
let insert_pos = if char_positions.len() > 1 {
136+
char_positions[rng.random_range(0..char_positions.len())]
137+
} else {
138+
0
139+
};
140+
s.insert_str(insert_pos, needle);
141+
haystacks.push(Some(s));
75142
}
76143
}
77144

78-
if is_string_view {
79-
let string_view_array: StringViewArray = output_string_vec.into_iter().collect();
80-
let sub_string_view_array: StringViewArray =
81-
output_sub_string_vec.into_iter().collect();
82-
vec![
83-
ColumnarValue::Array(Arc::new(string_view_array)),
84-
ColumnarValue::Array(Arc::new(sub_string_view_array)),
85-
]
86-
} else {
87-
let string_array: StringArray = output_string_vec.clone().into_iter().collect();
88-
let sub_string_array: StringArray = output_sub_string_vec.into_iter().collect();
89-
vec![
90-
ColumnarValue::Array(Arc::new(string_array)),
91-
ColumnarValue::Array(Arc::new(sub_string_array)),
92-
]
93-
}
145+
let needle_cv = ColumnarValue::Scalar(ScalarValue::Utf8(Some(needle.to_string())));
146+
vec![to_columnar_value(haystacks, is_string_view), needle_cv]
94147
}
95148

96-
fn random_substring(chars: Chars) -> String {
97-
// get the substring of a random length from the input string by byte unit
98-
let mut rng = StdRng::seed_from_u64(44);
99-
let count = chars.clone().count();
149+
/// Extracts a random contiguous substring from `s`.
150+
fn random_substring(rng: &mut StdRng, s: &str) -> String {
151+
let count = s.chars().count();
152+
153+
assert!(count > 0, "random_substring requires a non-empty string");
154+
if count == 1 {
155+
return s.to_string();
156+
}
157+
100158
let start = rng.random_range(0..count - 1);
101159
let end = rng.random_range(start + 1..count);
102-
chars
103-
.enumerate()
104-
.filter(|(i, _)| *i >= start && *i < end)
105-
.map(|(_, c)| c)
106-
.collect()
160+
s.chars().skip(start).take(end - start).collect()
161+
}
162+
163+
fn bench_strpos(
164+
c: &mut Criterion,
165+
name: &str,
166+
args: &[ColumnarValue],
167+
strpos: &datafusion_expr::ScalarUDF,
168+
) {
169+
let arg_fields = vec![Field::new("a", args[0].data_type(), true).into()];
170+
let return_field: Arc<Field> = Field::new("f", DataType::Int32, true).into();
171+
let config_options = Arc::new(ConfigOptions::default());
172+
173+
c.bench_function(name, |b| {
174+
b.iter(|| {
175+
black_box(strpos.invoke_with_args(ScalarFunctionArgs {
176+
args: args.to_vec(),
177+
arg_fields: arg_fields.clone(),
178+
number_rows: N_ROWS,
179+
return_field: Arc::clone(&return_field),
180+
config_options: Arc::clone(&config_options),
181+
}))
182+
})
183+
});
107184
}
108185

109186
fn criterion_benchmark(c: &mut Criterion) {
110-
// All benches are single batch run with 8192 rows
111187
let strpos = datafusion_functions::unicode::strpos();
188+
let mut rng = StdRng::seed_from_u64(42);
112189

113-
let n_rows = 8192;
114190
for str_len in [8, 32, 128, 4096] {
115-
// StringArray ASCII only
116-
let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false);
117-
let arg_fields =
118-
vec![Field::new("a", args_string_ascii[0].data_type(), true).into()];
119-
let return_field = Field::new("f", DataType::Int32, true).into();
120-
let config_options = Arc::new(ConfigOptions::default());
121-
122-
c.bench_function(
123-
&format!("strpos_StringArray_ascii_str_len_{str_len}"),
124-
|b| {
125-
b.iter(|| {
126-
black_box(strpos.invoke_with_args(ScalarFunctionArgs {
127-
args: args_string_ascii.clone(),
128-
arg_fields: arg_fields.clone(),
129-
number_rows: n_rows,
130-
return_field: Arc::clone(&return_field),
131-
config_options: Arc::clone(&config_options),
132-
}))
133-
})
134-
},
135-
);
136-
137-
// StringArray UTF8
138-
let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false);
139-
let arg_fields =
140-
vec![Field::new("a", args_string_utf8[0].data_type(), true).into()];
141-
let return_field = Field::new("f", DataType::Int32, true).into();
142-
c.bench_function(&format!("strpos_StringArray_utf8_str_len_{str_len}"), |b| {
143-
b.iter(|| {
144-
black_box(strpos.invoke_with_args(ScalarFunctionArgs {
145-
args: args_string_utf8.clone(),
146-
arg_fields: arg_fields.clone(),
147-
number_rows: n_rows,
148-
return_field: Arc::clone(&return_field),
149-
config_options: Arc::clone(&config_options),
150-
}))
151-
})
152-
});
153-
154-
// StringViewArray ASCII only
155-
let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true);
156-
let arg_fields =
157-
vec![Field::new("a", args_string_view_ascii[0].data_type(), true).into()];
158-
let return_field = Field::new("f", DataType::Int32, true).into();
159-
c.bench_function(
160-
&format!("strpos_StringViewArray_ascii_str_len_{str_len}"),
161-
|b| {
162-
b.iter(|| {
163-
black_box(strpos.invoke_with_args(ScalarFunctionArgs {
164-
args: args_string_view_ascii.clone(),
165-
arg_fields: arg_fields.clone(),
166-
number_rows: n_rows,
167-
return_field: Arc::clone(&return_field),
168-
config_options: Arc::clone(&config_options),
169-
}))
170-
})
171-
},
172-
);
173-
174-
// StringViewArray UTF8
175-
let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true);
176-
let arg_fields =
177-
vec![Field::new("a", args_string_view_utf8[0].data_type(), true).into()];
178-
let return_field = Field::new("f", DataType::Int32, true).into();
179-
c.bench_function(
180-
&format!("strpos_StringViewArray_utf8_str_len_{str_len}"),
181-
|b| {
182-
b.iter(|| {
183-
black_box(strpos.invoke_with_args(ScalarFunctionArgs {
184-
args: args_string_view_utf8.clone(),
185-
arg_fields: arg_fields.clone(),
186-
number_rows: n_rows,
187-
return_field: Arc::clone(&return_field),
188-
config_options: Arc::clone(&config_options),
189-
}))
190-
})
191-
},
192-
);
191+
// Array needle benchmarks
192+
for (label, utf8_density, is_view) in [
193+
("StringArray_ascii", 0.0, false),
194+
("StringArray_utf8", 0.5, false),
195+
("StringViewArray_ascii", 0.0, true),
196+
("StringViewArray_utf8", 0.5, true),
197+
] {
198+
let args =
199+
make_array_needle_args(&mut rng, str_len, 0.1, utf8_density, is_view);
200+
bench_strpos(
201+
c,
202+
&format!("strpos_{label}_str_len_{str_len}"),
203+
&args,
204+
strpos.as_ref(),
205+
);
206+
}
207+
208+
// Scalar needle benchmarks
209+
let needle = "xyz";
210+
for (label, utf8_density, is_view) in [
211+
("StringArray_scalar_needle_ascii", 0.0, false),
212+
("StringArray_scalar_needle_utf8", 0.5, false),
213+
("StringViewArray_scalar_needle_ascii", 0.0, true),
214+
("StringViewArray_scalar_needle_utf8", 0.5, true),
215+
] {
216+
let args = make_scalar_needle_args(
217+
&mut rng,
218+
str_len,
219+
needle,
220+
0.1,
221+
utf8_density,
222+
is_view,
223+
);
224+
bench_strpos(
225+
c,
226+
&format!("strpos_{label}_str_len_{str_len}"),
227+
&args,
228+
strpos.as_ref(),
229+
);
230+
}
193231
}
194232
}
195233

0 commit comments

Comments
 (0)