|
18 | 18 | use arrow::array::{StringArray, StringViewArray}; |
19 | 19 | use arrow::datatypes::{DataType, Field}; |
20 | 20 | use criterion::{Criterion, criterion_group, criterion_main}; |
| 21 | +use datafusion_common::ScalarValue; |
21 | 22 | use datafusion_common::config::ConfigOptions; |
22 | 23 | use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; |
23 | 24 | use rand::distr::Alphanumeric; |
24 | 25 | use rand::prelude::StdRng; |
25 | 26 | use rand::{Rng, SeedableRng}; |
26 | 27 | use std::hint::black_box; |
27 | | -use std::str::Chars; |
28 | 28 | use std::sync::Arc; |
29 | 29 |
|
30 | | -/// Returns a `Vec<ColumnarValue>` with two elements: a haystack array and a |
31 | | -/// needle array. Each haystack is a random string of `str_len_chars` |
32 | | -/// characters. Each needle is a random contiguous substring of its |
33 | | -/// corresponding haystack (i.e., the needle is always present in the haystack). |
34 | | -/// Around `null_density` fraction of rows are null and `utf8_density` fraction |
35 | | -/// contain non-ASCII characters; the remaining rows are ASCII-only. |
36 | | -fn gen_string_array( |
37 | | - n_rows: usize, |
| 30 | +#[rustfmt::skip] |
| 31 | +const UTF8_CORPUS: &[char] = &[ |
| 32 | + // Cyrillic (2 bytes each) |
| 33 | + 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ж', 'З', 'И', 'К', 'Л', 'М', 'Н', 'О', 'П', 'Р', 'С', |
| 34 | + 'Т', 'У', 'Ф', 'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Э', 'Ю', 'Я', |
| 35 | + // CJK (3 bytes each) |
| 36 | + '数', '据', '融', '合', '查', '询', '引', '擎', '优', '化', '执', '行', '计', '划', |
| 37 | + '表', '达', |
| 38 | + // Emoji (4 bytes each) |
| 39 | + '📊', '🔥', '🚀', '⚡', '🎯', '💡', '🔧', '📈', |
| 40 | +]; |
| 41 | +const N_ROWS: usize = 8192; |
| 42 | + |
| 43 | +/// Returns a random string of `len` characters. If `ascii` is true, the string |
| 44 | +/// is ASCII-only; otherwise it is drawn from `UTF8_CORPUS`. |
| 45 | +fn random_string(rng: &mut StdRng, len: usize, ascii: bool) -> String { |
| 46 | + if ascii { |
| 47 | + let value: Vec<u8> = rng.sample_iter(&Alphanumeric).take(len).collect(); |
| 48 | + String::from_utf8(value).unwrap() |
| 49 | + } else { |
| 50 | + (0..len) |
| 51 | + .map(|_| UTF8_CORPUS[rng.random_range(0..UTF8_CORPUS.len())]) |
| 52 | + .collect() |
| 53 | + } |
| 54 | +} |
| 55 | + |
| 56 | +/// Wraps `strings` into either a `StringArray` or `StringViewArray`. |
| 57 | +fn to_columnar_value( |
| 58 | + strings: Vec<Option<String>>, |
| 59 | + is_string_view: bool, |
| 60 | +) -> ColumnarValue { |
| 61 | + if is_string_view { |
| 62 | + let arr: StringViewArray = strings.into_iter().collect(); |
| 63 | + ColumnarValue::Array(Arc::new(arr)) |
| 64 | + } else { |
| 65 | + let arr: StringArray = strings.into_iter().collect(); |
| 66 | + ColumnarValue::Array(Arc::new(arr)) |
| 67 | + } |
| 68 | +} |
| 69 | + |
| 70 | +/// Returns haystack and needle, where both are arrays. Each needle is a |
| 71 | +/// contiguous substring of its corresponding haystack. Around `null_density` |
| 72 | +/// fraction of rows are null and `utf8_density` fraction contain non-ASCII |
| 73 | +/// characters. |
| 74 | +fn make_array_needle_args( |
| 75 | + rng: &mut StdRng, |
38 | 76 | str_len_chars: usize, |
39 | 77 | null_density: f32, |
40 | 78 | utf8_density: f32, |
41 | | - is_string_view: bool, // false -> StringArray, true -> StringViewArray |
| 79 | + is_string_view: bool, |
42 | 80 | ) -> Vec<ColumnarValue> { |
43 | | - let mut rng = StdRng::seed_from_u64(42); |
44 | | - let rng_ref = &mut rng; |
45 | | - |
46 | | - let utf8 = "DatafusionДатаФусион数据融合📊🔥"; // includes utf8 encoding with 1~4 bytes |
47 | | - let corpus_char_count = utf8.chars().count(); |
48 | | - |
49 | | - let mut output_string_vec: Vec<Option<String>> = Vec::with_capacity(n_rows); |
50 | | - let mut output_sub_string_vec: Vec<Option<String>> = Vec::with_capacity(n_rows); |
51 | | - for _ in 0..n_rows { |
52 | | - let rand_num = rng_ref.random::<f32>(); // [0.0, 1.0) |
53 | | - if rand_num < null_density { |
54 | | - output_sub_string_vec.push(None); |
55 | | - output_string_vec.push(None); |
56 | | - } else if rand_num < null_density + utf8_density { |
57 | | - // Generate random UTF8 string |
58 | | - let mut generated_string = String::with_capacity(str_len_chars); |
59 | | - for _ in 0..str_len_chars { |
60 | | - let idx = rng_ref.random_range(0..corpus_char_count); |
61 | | - let char = utf8.chars().nth(idx).unwrap(); |
62 | | - generated_string.push(char); |
63 | | - } |
64 | | - output_sub_string_vec.push(Some(random_substring(generated_string.chars()))); |
65 | | - output_string_vec.push(Some(generated_string)); |
| 81 | + let mut haystacks: Vec<Option<String>> = Vec::with_capacity(N_ROWS); |
| 82 | + let mut needles: Vec<Option<String>> = Vec::with_capacity(N_ROWS); |
| 83 | + for _ in 0..N_ROWS { |
| 84 | + let r = rng.random::<f32>(); |
| 85 | + if r < null_density { |
| 86 | + haystacks.push(None); |
| 87 | + needles.push(None); |
66 | 88 | } else { |
67 | | - // Generate random ASCII-only string |
68 | | - let value = rng_ref |
| 89 | + let ascii = r >= null_density + utf8_density; |
| 90 | + let s = random_string(rng, str_len_chars, ascii); |
| 91 | + needles.push(Some(random_substring(rng, &s))); |
| 92 | + haystacks.push(Some(s)); |
| 93 | + } |
| 94 | + } |
| 95 | + |
| 96 | + vec![ |
| 97 | + to_columnar_value(haystacks, is_string_view), |
| 98 | + to_columnar_value(needles, is_string_view), |
| 99 | + ] |
| 100 | +} |
| 101 | + |
| 102 | +/// Returns haystack array with a fixed scalar needle inserted into each row. |
| 103 | +/// Around `null_density` fraction of rows are null and `utf8_density` fraction |
| 104 | +/// contain non-ASCII characters. The needle must be ASCII. |
| 105 | +fn make_scalar_needle_args( |
| 106 | + rng: &mut StdRng, |
| 107 | + str_len_chars: usize, |
| 108 | + needle: &str, |
| 109 | + null_density: f32, |
| 110 | + utf8_density: f32, |
| 111 | + is_string_view: bool, |
| 112 | +) -> Vec<ColumnarValue> { |
| 113 | + let needle_len = needle.len(); |
| 114 | + assert!( |
| 115 | + str_len_chars >= needle_len, |
| 116 | + "str_len_chars must be >= needle length" |
| 117 | + ); |
| 118 | + |
| 119 | + let mut haystacks: Vec<Option<String>> = Vec::with_capacity(N_ROWS); |
| 120 | + for _ in 0..N_ROWS { |
| 121 | + let r = rng.random::<f32>(); |
| 122 | + if r < null_density { |
| 123 | + haystacks.push(None); |
| 124 | + } else if r >= null_density + utf8_density { |
| 125 | + let mut value: Vec<u8> = (&mut *rng) |
69 | 126 | .sample_iter(&Alphanumeric) |
70 | 127 | .take(str_len_chars) |
71 | 128 | .collect(); |
72 | | - let value = String::from_utf8(value).unwrap(); |
73 | | - output_sub_string_vec.push(Some(random_substring(value.chars()))); |
74 | | - output_string_vec.push(Some(value)); |
| 129 | + let pos = rng.random_range(0..=str_len_chars - needle_len); |
| 130 | + value[pos..pos + needle_len].copy_from_slice(needle.as_bytes()); |
| 131 | + haystacks.push(Some(String::from_utf8(value).unwrap())); |
| 132 | + } else { |
| 133 | + let mut s = random_string(rng, str_len_chars, false); |
| 134 | + let char_positions: Vec<usize> = s.char_indices().map(|(i, _)| i).collect(); |
| 135 | + let insert_pos = if char_positions.len() > 1 { |
| 136 | + char_positions[rng.random_range(0..char_positions.len())] |
| 137 | + } else { |
| 138 | + 0 |
| 139 | + }; |
| 140 | + s.insert_str(insert_pos, needle); |
| 141 | + haystacks.push(Some(s)); |
75 | 142 | } |
76 | 143 | } |
77 | 144 |
|
78 | | - if is_string_view { |
79 | | - let string_view_array: StringViewArray = output_string_vec.into_iter().collect(); |
80 | | - let sub_string_view_array: StringViewArray = |
81 | | - output_sub_string_vec.into_iter().collect(); |
82 | | - vec![ |
83 | | - ColumnarValue::Array(Arc::new(string_view_array)), |
84 | | - ColumnarValue::Array(Arc::new(sub_string_view_array)), |
85 | | - ] |
86 | | - } else { |
87 | | - let string_array: StringArray = output_string_vec.clone().into_iter().collect(); |
88 | | - let sub_string_array: StringArray = output_sub_string_vec.into_iter().collect(); |
89 | | - vec![ |
90 | | - ColumnarValue::Array(Arc::new(string_array)), |
91 | | - ColumnarValue::Array(Arc::new(sub_string_array)), |
92 | | - ] |
93 | | - } |
| 145 | + let needle_cv = ColumnarValue::Scalar(ScalarValue::Utf8(Some(needle.to_string()))); |
| 146 | + vec![to_columnar_value(haystacks, is_string_view), needle_cv] |
94 | 147 | } |
95 | 148 |
|
96 | | -fn random_substring(chars: Chars) -> String { |
97 | | - // get the substring of a random length from the input string by byte unit |
98 | | - let mut rng = StdRng::seed_from_u64(44); |
99 | | - let count = chars.clone().count(); |
| 149 | +/// Extracts a random contiguous substring from `s`. |
| 150 | +fn random_substring(rng: &mut StdRng, s: &str) -> String { |
| 151 | + let count = s.chars().count(); |
| 152 | + |
| 153 | + assert!(count > 0, "random_substring requires a non-empty string"); |
| 154 | + if count == 1 { |
| 155 | + return s.to_string(); |
| 156 | + } |
| 157 | + |
100 | 158 | let start = rng.random_range(0..count - 1); |
101 | 159 | let end = rng.random_range(start + 1..count); |
102 | | - chars |
103 | | - .enumerate() |
104 | | - .filter(|(i, _)| *i >= start && *i < end) |
105 | | - .map(|(_, c)| c) |
106 | | - .collect() |
| 160 | + s.chars().skip(start).take(end - start).collect() |
| 161 | +} |
| 162 | + |
| 163 | +fn bench_strpos( |
| 164 | + c: &mut Criterion, |
| 165 | + name: &str, |
| 166 | + args: &[ColumnarValue], |
| 167 | + strpos: &datafusion_expr::ScalarUDF, |
| 168 | +) { |
| 169 | + let arg_fields = vec![Field::new("a", args[0].data_type(), true).into()]; |
| 170 | + let return_field: Arc<Field> = Field::new("f", DataType::Int32, true).into(); |
| 171 | + let config_options = Arc::new(ConfigOptions::default()); |
| 172 | + |
| 173 | + c.bench_function(name, |b| { |
| 174 | + b.iter(|| { |
| 175 | + black_box(strpos.invoke_with_args(ScalarFunctionArgs { |
| 176 | + args: args.to_vec(), |
| 177 | + arg_fields: arg_fields.clone(), |
| 178 | + number_rows: N_ROWS, |
| 179 | + return_field: Arc::clone(&return_field), |
| 180 | + config_options: Arc::clone(&config_options), |
| 181 | + })) |
| 182 | + }) |
| 183 | + }); |
107 | 184 | } |
108 | 185 |
|
109 | 186 | fn criterion_benchmark(c: &mut Criterion) { |
110 | | - // All benches are single batch run with 8192 rows |
111 | 187 | let strpos = datafusion_functions::unicode::strpos(); |
| 188 | + let mut rng = StdRng::seed_from_u64(42); |
112 | 189 |
|
113 | | - let n_rows = 8192; |
114 | 190 | for str_len in [8, 32, 128, 4096] { |
115 | | - // StringArray ASCII only |
116 | | - let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); |
117 | | - let arg_fields = |
118 | | - vec![Field::new("a", args_string_ascii[0].data_type(), true).into()]; |
119 | | - let return_field = Field::new("f", DataType::Int32, true).into(); |
120 | | - let config_options = Arc::new(ConfigOptions::default()); |
121 | | - |
122 | | - c.bench_function( |
123 | | - &format!("strpos_StringArray_ascii_str_len_{str_len}"), |
124 | | - |b| { |
125 | | - b.iter(|| { |
126 | | - black_box(strpos.invoke_with_args(ScalarFunctionArgs { |
127 | | - args: args_string_ascii.clone(), |
128 | | - arg_fields: arg_fields.clone(), |
129 | | - number_rows: n_rows, |
130 | | - return_field: Arc::clone(&return_field), |
131 | | - config_options: Arc::clone(&config_options), |
132 | | - })) |
133 | | - }) |
134 | | - }, |
135 | | - ); |
136 | | - |
137 | | - // StringArray UTF8 |
138 | | - let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); |
139 | | - let arg_fields = |
140 | | - vec![Field::new("a", args_string_utf8[0].data_type(), true).into()]; |
141 | | - let return_field = Field::new("f", DataType::Int32, true).into(); |
142 | | - c.bench_function(&format!("strpos_StringArray_utf8_str_len_{str_len}"), |b| { |
143 | | - b.iter(|| { |
144 | | - black_box(strpos.invoke_with_args(ScalarFunctionArgs { |
145 | | - args: args_string_utf8.clone(), |
146 | | - arg_fields: arg_fields.clone(), |
147 | | - number_rows: n_rows, |
148 | | - return_field: Arc::clone(&return_field), |
149 | | - config_options: Arc::clone(&config_options), |
150 | | - })) |
151 | | - }) |
152 | | - }); |
153 | | - |
154 | | - // StringViewArray ASCII only |
155 | | - let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); |
156 | | - let arg_fields = |
157 | | - vec![Field::new("a", args_string_view_ascii[0].data_type(), true).into()]; |
158 | | - let return_field = Field::new("f", DataType::Int32, true).into(); |
159 | | - c.bench_function( |
160 | | - &format!("strpos_StringViewArray_ascii_str_len_{str_len}"), |
161 | | - |b| { |
162 | | - b.iter(|| { |
163 | | - black_box(strpos.invoke_with_args(ScalarFunctionArgs { |
164 | | - args: args_string_view_ascii.clone(), |
165 | | - arg_fields: arg_fields.clone(), |
166 | | - number_rows: n_rows, |
167 | | - return_field: Arc::clone(&return_field), |
168 | | - config_options: Arc::clone(&config_options), |
169 | | - })) |
170 | | - }) |
171 | | - }, |
172 | | - ); |
173 | | - |
174 | | - // StringViewArray UTF8 |
175 | | - let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); |
176 | | - let arg_fields = |
177 | | - vec![Field::new("a", args_string_view_utf8[0].data_type(), true).into()]; |
178 | | - let return_field = Field::new("f", DataType::Int32, true).into(); |
179 | | - c.bench_function( |
180 | | - &format!("strpos_StringViewArray_utf8_str_len_{str_len}"), |
181 | | - |b| { |
182 | | - b.iter(|| { |
183 | | - black_box(strpos.invoke_with_args(ScalarFunctionArgs { |
184 | | - args: args_string_view_utf8.clone(), |
185 | | - arg_fields: arg_fields.clone(), |
186 | | - number_rows: n_rows, |
187 | | - return_field: Arc::clone(&return_field), |
188 | | - config_options: Arc::clone(&config_options), |
189 | | - })) |
190 | | - }) |
191 | | - }, |
192 | | - ); |
| 191 | + // Array needle benchmarks |
| 192 | + for (label, utf8_density, is_view) in [ |
| 193 | + ("StringArray_ascii", 0.0, false), |
| 194 | + ("StringArray_utf8", 0.5, false), |
| 195 | + ("StringViewArray_ascii", 0.0, true), |
| 196 | + ("StringViewArray_utf8", 0.5, true), |
| 197 | + ] { |
| 198 | + let args = |
| 199 | + make_array_needle_args(&mut rng, str_len, 0.1, utf8_density, is_view); |
| 200 | + bench_strpos( |
| 201 | + c, |
| 202 | + &format!("strpos_{label}_str_len_{str_len}"), |
| 203 | + &args, |
| 204 | + strpos.as_ref(), |
| 205 | + ); |
| 206 | + } |
| 207 | + |
| 208 | + // Scalar needle benchmarks |
| 209 | + let needle = "xyz"; |
| 210 | + for (label, utf8_density, is_view) in [ |
| 211 | + ("StringArray_scalar_needle_ascii", 0.0, false), |
| 212 | + ("StringArray_scalar_needle_utf8", 0.5, false), |
| 213 | + ("StringViewArray_scalar_needle_ascii", 0.0, true), |
| 214 | + ("StringViewArray_scalar_needle_utf8", 0.5, true), |
| 215 | + ] { |
| 216 | + let args = make_scalar_needle_args( |
| 217 | + &mut rng, |
| 218 | + str_len, |
| 219 | + needle, |
| 220 | + 0.1, |
| 221 | + utf8_density, |
| 222 | + is_view, |
| 223 | + ); |
| 224 | + bench_strpos( |
| 225 | + c, |
| 226 | + &format!("strpos_{label}_str_len_{str_len}"), |
| 227 | + &args, |
| 228 | + strpos.as_ref(), |
| 229 | + ); |
| 230 | + } |
193 | 231 | } |
194 | 232 | } |
195 | 233 |
|
|
0 commit comments