Skip to content

Commit 4ad5c3d

Browse files
authored
perf: Optimize strpos() for ASCII-only inputs (#20295)
The previous implementation had a fast path for ASCII-only inputs, but it was still relatively slow. Switch to using memchr::memchr() to find the first matching byte and then check the rest of the bytes by hand. This improves performance for ASCII inputs by 2x-4x on the built-in strpos benchmarks. ## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #20294. ## Are these changes tested? Yes, passes unit tests and SLT. ## Are there any user-facing changes? No.
1 parent 682da84 commit 4ad5c3d

4 files changed

Lines changed: 51 additions & 32 deletions

File tree

Cargo.lock

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/functions/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ hex = { workspace = true, optional = true }
8282
itertools = { workspace = true }
8383
log = { workspace = true }
8484
md-5 = { version = "^0.10.0", optional = true }
85+
memchr = "2.8.0"
8586
num-traits = { workspace = true }
8687
rand = { workspace = true }
8788
regex = { workspace = true, optional = true }

datafusion/functions/benches/strpos.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,12 @@ use std::hint::black_box;
2727
use std::str::Chars;
2828
use std::sync::Arc;
2929

30-
/// gen_arr(4096, 128, 0.1, 0.1, true) will generate a StringViewArray with
31-
/// 4096 rows, each row containing a string with 128 random characters.
32-
/// around 10% of the rows are null, around 10% of the rows are non-ASCII.
30+
/// Returns a `Vec<ColumnarValue>` with two elements: a haystack array and a
31+
/// needle array. Each haystack is a random string of `str_len_chars`
32+
/// characters. Each needle is a random contiguous substring of its
33+
/// corresponding haystack (i.e., the needle is always present in the haystack).
34+
/// Around `null_density` fraction of rows are null and `utf8_density` fraction
35+
/// contain non-ASCII characters; the remaining rows are ASCII-only.
3336
fn gen_string_array(
3437
n_rows: usize,
3538
str_len_chars: usize,

datafusion/functions/src/unicode/strpos.rs

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use datafusion_expr::{
3232
Volatility,
3333
};
3434
use datafusion_macros::user_doc;
35+
use memchr::memchr;
3536

3637
#[user_doc(
3738
doc_section(label = "String Functions"),
@@ -179,6 +180,31 @@ fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
179180
}
180181
}
181182

183+
/// Find `needle` in `haystack` using `memchr` to quickly skip to positions
184+
/// where the first byte matches, then verify the remaining bytes. Using
185+
/// string::find is slower because it has significant per-call overhead that
186+
/// `memchr` does not, and strpos is often invoked many times on short inputs.
187+
/// Returns a 1-based position, or 0 if not found.
188+
/// Both inputs must be ASCII-only.
189+
fn find_ascii_substring(haystack: &[u8], needle: &[u8]) -> usize {
190+
let needle_len = needle.len();
191+
let first_byte = needle[0];
192+
let mut offset = 0;
193+
194+
while let Some(pos) = memchr(first_byte, &haystack[offset..]) {
195+
let start = offset + pos;
196+
if start + needle_len > haystack.len() {
197+
return 0;
198+
}
199+
if haystack[start..start + needle_len] == *needle {
200+
return start + 1;
201+
}
202+
offset = start + 1;
203+
}
204+
205+
0
206+
}
207+
182208
/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)
183209
/// strpos('high', 'ig') = 2
184210
/// The implementation uses UTF-8 code points as characters
@@ -198,37 +224,25 @@ where
198224
.zip(substring_iter)
199225
.map(|(string, substring)| match (string, substring) {
200226
(Some(string), Some(substring)) => {
201-
// If only ASCII characters are present, we can use the slide window method to find
202-
// the sub vector in the main vector. This is faster than string.find() method.
227+
if substring.is_empty() {
228+
return T::Native::from_usize(1);
229+
}
230+
231+
let substring_bytes = substring.as_bytes();
232+
let string_bytes = string.as_bytes();
233+
234+
if substring_bytes.len() > string_bytes.len() {
235+
return T::Native::from_usize(0);
236+
}
237+
203238
if ascii_only {
204-
// If the substring is empty, the result is 1.
205-
if substring.is_empty() {
206-
T::Native::from_usize(1)
207-
} else {
208-
T::Native::from_usize(
209-
string
210-
.as_bytes()
211-
.windows(substring.len())
212-
.position(|w| w == substring.as_bytes())
213-
.map(|x| x + 1)
214-
.unwrap_or(0),
215-
)
216-
}
239+
T::Native::from_usize(find_ascii_substring(
240+
string_bytes,
241+
substring_bytes,
242+
))
217243
} else {
218244
// For non-ASCII, use a single-pass search that tracks both
219245
// byte position and character position simultaneously
220-
if substring.is_empty() {
221-
return T::Native::from_usize(1);
222-
}
223-
224-
let substring_bytes = substring.as_bytes();
225-
let string_bytes = string.as_bytes();
226-
227-
if substring_bytes.len() > string_bytes.len() {
228-
return T::Native::from_usize(0);
229-
}
230-
231-
// Single pass: find substring while counting characters
232246
let mut char_pos = 0;
233247
for (byte_idx, _) in string.char_indices() {
234248
char_pos += 1;

0 commit comments

Comments
 (0)