Skip to content

Commit aab4263

Browse files
authored
perf(substr_index): speed up scalar and Utf8View (#21754)
## Which issue does this PR close?

<!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. -->

- Part of #20585.

## Rationale for this change

`substr_index` is part of the Utf8View epic. This function was still missing the Utf8View optimization path, and it also did extra work when delimiter and count were constant scalars. This PR adds the same kind of optimization used in similar string work and includes benchmark coverage.

<!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. -->

## What changes are included in this PR?

- Return the input string type directly, including `Utf8View`.
- Add a zero-copy `Utf8View` path for `substr_index`.
- Add a scalar fast path when delimiter and count are constant.
- Keep the hot path specialized after checking a cleaner shared-helper version and benchmarking it.
- Add SQL coverage for `Utf8View` type and value results.
- Add unit tests for the scalar fast path, sliced `Utf8View` arrays, and the unchanged original-view case.
- Extend the benchmark to cover `Utf8`, `Utf8View`, array arguments, scalar delimiter/count, and positive and negative counts.

<!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. -->

## Are these changes tested?

Yes

<!-- We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? -->

## Are there any user-facing changes?

No

<!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. -->

<!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 89e14f1 commit aab4263

3 files changed

Lines changed: 749 additions & 136 deletions

File tree

datafusion/functions/benches/substr_index.rs

Lines changed: 143 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -18,15 +18,20 @@
1818
use std::hint::black_box;
1919
use std::sync::Arc;
2020

21-
use arrow::array::{ArrayRef, Int64Array, StringArray};
21+
use arrow::array::{ArrayRef, Int64Array, StringArray, StringViewArray};
2222
use arrow::datatypes::{DataType, Field};
2323
use criterion::{Criterion, criterion_group, criterion_main};
24-
use datafusion_common::config::ConfigOptions;
24+
use datafusion_common::{ScalarValue, config::ConfigOptions};
2525
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
2626
use datafusion_functions::unicode::substr_index;
2727
use rand::Rng;
28+
use rand::SeedableRng;
2829
use rand::distr::{Alphanumeric, Uniform};
2930
use rand::prelude::Distribution;
31+
use rand::rngs::StdRng;
32+
33+
const ARRAY_DATA_SEED: u64 = 0x5EED_AAAA;
34+
const SCALAR_DATA_SEED: u64 = 0x5EED_BBBB;
3035

3136
struct Filter<Dist, Test> {
3237
dist: Dist,
@@ -48,28 +53,72 @@ where
4853
}
4954
}
5055

51-
fn data(
56+
/// Selects which Arrow string representation a benchmark case is run against.
#[derive(Clone, Copy)]
57+
enum StringRep {
58+
/// Benchmark inputs are built as `StringArray` / `DataType::Utf8`.
Utf8,
59+
/// Benchmark inputs are built as `StringViewArray` / `DataType::Utf8View`.
Utf8View,
60+
}
61+
62+
// Helpers that turn plain Rust strings into the arrays, scalars, and type
// metadata needed for the selected string representation.
impl StringRep {
63+
/// Returns a short lowercase label (`"utf8"` or `"utf8view"`) for this
/// representation, suitable for embedding in a benchmark name.
fn name(self) -> &'static str {
64+
match self {
65+
Self::Utf8 => "utf8",
66+
Self::Utf8View => "utf8view",
67+
}
68+
}
69+
70+
/// Returns the Arrow `DataType` that corresponds to this representation.
fn data_type(self) -> DataType {
71+
match self {
72+
Self::Utf8 => DataType::Utf8,
73+
Self::Utf8View => DataType::Utf8View,
74+
}
75+
}
76+
77+
/// Builds an `ArrayRef` holding `values` in this representation.
///
/// Every element is non-null: the `Utf8View` arm wraps each value in
/// `Some`, and the `Utf8` arm builds the array from owned `String`s.
fn array(self, values: &[String]) -> ArrayRef {
78+
match self {
79+
// `to_vec` clones the strings so the caller keeps ownership of `values`.
Self::Utf8 => Arc::new(StringArray::from(values.to_vec())) as ArrayRef,
80+
Self::Utf8View => Arc::new(
81+
values
82+
.iter()
83+
.map(|value| Some(value.as_str()))
84+
.collect::<StringViewArray>(),
85+
) as ArrayRef,
86+
}
87+
}
88+
89+
/// Wraps `value` (copied into an owned `String`) in the `ScalarValue`
/// variant matching this representation.
fn scalar(self, value: &str) -> ScalarValue {
90+
match self {
91+
Self::Utf8 => ScalarValue::Utf8(Some(value.to_string())),
92+
Self::Utf8View => ScalarValue::Utf8View(Some(value.to_string())),
93+
}
94+
}
95+
}
96+
97+
/// Builds a `String` of `len` characters sampled from the `Alphanumeric`
/// distribution using the supplied random number generator.
fn random_token<R: Rng + ?Sized>(rng: &mut R, len: usize) -> String {
98+
rng.sample_iter(&Alphanumeric)
99+
.take(len)
100+
.map(char::from)
101+
.collect()
102+
}
103+
104+
fn array_data(
52105
batch_size: usize,
53106
single_char_delimiter: bool,
54-
) -> (StringArray, StringArray, Int64Array) {
55-
let dist = Filter {
56-
dist: Uniform::new(-4, 5),
107+
) -> (Vec<String>, Vec<String>, Vec<i64>) {
108+
let count_dist = Filter {
109+
dist: Uniform::new(-4, 5).expect("valid count distribution"),
57110
test: |x: &i64| x != &0,
58111
};
59-
let mut rng = rand::rng();
60-
let mut strings: Vec<String> = vec![];
61-
let mut delimiters: Vec<String> = vec![];
62-
let mut counts: Vec<i64> = vec![];
112+
let mut rng = StdRng::seed_from_u64(ARRAY_DATA_SEED);
113+
let mut strings = Vec::with_capacity(batch_size);
114+
let mut delimiters = Vec::with_capacity(batch_size);
115+
let mut counts = Vec::with_capacity(batch_size);
63116

64117
for _ in 0..batch_size {
65118
let length = rng.random_range(20..50);
66-
let base: String = (&mut rng)
67-
.sample_iter(&Alphanumeric)
68-
.take(length)
69-
.map(char::from)
70-
.collect();
119+
let base = random_token(&mut rng, length);
71120

72-
let (string_value, delimiter): (String, String) = if single_char_delimiter {
121+
let (string_value, delimiter) = if single_char_delimiter {
73122
let char_idx = rng.random_range(0..base.chars().count());
74123
let delimiter = base.chars().nth(char_idx).unwrap().to_string();
75124
(base, delimiter)
@@ -80,7 +129,6 @@ fn data(
80129

81130
let delimiter_count = rng.random_range(1..4);
82131
let mut result = String::new();
83-
84132
for i in 0..delimiter_count {
85133
result.push_str(&base);
86134
if i < delimiter_count - 1 {
@@ -90,32 +138,37 @@ fn data(
90138
(result, delimiter)
91139
};
92140

93-
let count = rng.sample(dist.dist.unwrap());
94-
95141
strings.push(string_value);
96142
delimiters.push(delimiter);
97-
counts.push(count);
143+
counts.push(count_dist.sample(&mut rng));
144+
}
145+
146+
(strings, delimiters, counts)
147+
}
148+
149+
/// Generates `batch_size` input strings for the scalar delimiter/count benchmarks.
///
/// Each string has the form `{left}{delimiter}{middle}{delimiter}{right}`,
/// where the three tokens come from `random_token` with lengths drawn from
/// `12..24`. The generator is seeded with `SCALAR_DATA_SEED`, so every call
/// with the same arguments starts from the same RNG state.
fn scalar_data(batch_size: usize, delimiter: &str) -> Vec<String> {
150+
let mut rng = StdRng::seed_from_u64(SCALAR_DATA_SEED);
151+
let mut strings = Vec::with_capacity(batch_size);
152+
153+
for _ in 0..batch_size {
154+
let left_len = rng.random_range(12..24);
155+
let middle_len = rng.random_range(12..24);
156+
let right_len = rng.random_range(12..24);
157+
let left = random_token(&mut rng, left_len);
158+
let middle = random_token(&mut rng, middle_len);
159+
let right = random_token(&mut rng, right_len);
160+
strings.push(format!("{left}{delimiter}{middle}{delimiter}{right}"));
98161
}
99162

100-
(
101-
StringArray::from(strings),
102-
StringArray::from(delimiters),
103-
Int64Array::from(counts),
104-
)
163+
strings
105164
}
106165

107166
fn run_benchmark(
108167
b: &mut criterion::Bencher,
109-
strings: StringArray,
110-
delimiters: StringArray,
111-
counts: Int64Array,
112-
batch_size: usize,
168+
args: &[ColumnarValue],
169+
return_type: &DataType,
170+
number_rows: usize,
113171
) {
114-
let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef);
115-
let delimiters = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef);
116-
let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef);
117-
118-
let args = vec![strings, delimiters, counts];
119172
let arg_fields = args
120173
.iter()
121174
.enumerate()
@@ -129,10 +182,10 @@ fn run_benchmark(
129182
black_box(
130183
substr_index()
131184
.invoke_with_args(ScalarFunctionArgs {
132-
args: args.clone(),
185+
args: args.to_vec(),
133186
arg_fields: arg_fields.clone(),
134-
number_rows: batch_size,
135-
return_field: Field::new("f", DataType::Utf8, true).into(),
187+
number_rows,
188+
return_field: Field::new("f", return_type.clone(), true).into(),
136189
config_options: Arc::clone(&config_options),
137190
})
138191
.expect("substr_index should work on valid values"),
@@ -142,22 +195,62 @@ fn run_benchmark(
142195

143196
fn criterion_benchmark(c: &mut Criterion) {
144197
let mut group = c.benchmark_group("substr_index");
145-
146198
let batch_sizes = [100, 1000, 10_000];
147199

148200
for batch_size in batch_sizes {
149-
group.bench_function(
150-
format!("substr_index_{batch_size}_single_delimiter"),
151-
|b| {
152-
let (strings, delimiters, counts) = data(batch_size, true);
153-
run_benchmark(b, strings, delimiters, counts, batch_size);
154-
},
155-
);
156-
157-
group.bench_function(format!("substr_index_{batch_size}_long_delimiter"), |b| {
158-
let (strings, delimiters, counts) = data(batch_size, false);
159-
run_benchmark(b, strings, delimiters, counts, batch_size);
160-
});
201+
for rep in [StringRep::Utf8, StringRep::Utf8View] {
202+
let rep_name = rep.name();
203+
204+
group.bench_function(
205+
format!("substr_index_{rep_name}_{batch_size}_array_single_delimiter"),
206+
|b| {
207+
let (strings, delimiters, counts) = array_data(batch_size, true);
208+
let args = vec![
209+
ColumnarValue::Array(rep.array(&strings)),
210+
ColumnarValue::Array(rep.array(&delimiters)),
211+
ColumnarValue::Array(
212+
Arc::new(Int64Array::from(counts)) as ArrayRef
213+
),
214+
];
215+
run_benchmark(b, &args, &rep.data_type(), batch_size);
216+
},
217+
);
218+
219+
group.bench_function(
220+
format!("substr_index_{rep_name}_{batch_size}_array_long_delimiter"),
221+
|b| {
222+
let (strings, delimiters, counts) = array_data(batch_size, false);
223+
let args = vec![
224+
ColumnarValue::Array(rep.array(&strings)),
225+
ColumnarValue::Array(rep.array(&delimiters)),
226+
ColumnarValue::Array(
227+
Arc::new(Int64Array::from(counts)) as ArrayRef
228+
),
229+
];
230+
run_benchmark(b, &args, &rep.data_type(), batch_size);
231+
},
232+
);
233+
234+
for (name, delimiter, count) in [
235+
("single_delimiter_pos", ".", 1_i64),
236+
("single_delimiter_neg", ".", -1_i64),
237+
("long_delimiter_pos", "|||", 1_i64),
238+
("long_delimiter_neg", "|||", -1_i64),
239+
] {
240+
group.bench_function(
241+
format!("substr_index_{rep_name}_{batch_size}_scalar_{name}"),
242+
|b| {
243+
let strings = scalar_data(batch_size, delimiter);
244+
let args = vec![
245+
ColumnarValue::Array(rep.array(&strings)),
246+
ColumnarValue::Scalar(rep.scalar(delimiter)),
247+
ColumnarValue::Scalar(ScalarValue::Int64(Some(count))),
248+
];
249+
run_benchmark(b, &args, &rep.data_type(), batch_size);
250+
},
251+
);
252+
}
253+
}
161254
}
162255

163256
group.finish();

0 commit comments

Comments
 (0)