Skip to content

Commit 9de1253

Browse files
authored
Update repeat UDF to emit utf8view when input is utf8view (#20645)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> Part of #20585 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Functions ideally should emit strings in the same format as the input and previously the repeat function was emitting using utf8 for input that was in utf8view. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> Code, tests ## Are these changes tested? Yes <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent e74b83a commit 9de1253

2 files changed

Lines changed: 95 additions & 27 deletions

File tree

datafusion/functions/src/string/repeat.rs

Lines changed: 71 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use std::sync::Arc;
2020
use crate::utils::utf8_to_str_type;
2121
use arrow::array::{
2222
Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
23-
OffsetSizeTrait, StringArrayType, StringViewArray,
23+
StringArrayType, StringLikeArrayBuilder, StringViewArray, StringViewBuilder,
2424
};
2525
use arrow::datatypes::DataType;
2626
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
@@ -91,6 +91,9 @@ impl ScalarUDFImpl for RepeatFunc {
9191
}
9292

9393
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
94+
if arg_types[0] == Utf8View {
95+
return Ok(Utf8View);
96+
}
9497
utf8_to_str_type(&arg_types[0], "repeat")
9598
}
9699

@@ -126,13 +129,12 @@ impl ScalarUDFImpl for RepeatFunc {
126129
};
127130

128131
let result = match string_scalar {
129-
ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => {
130-
ScalarValue::Utf8(Some(compute_repeat(
131-
s,
132-
count,
133-
i32::MAX as usize,
134-
)?))
135-
}
132+
ScalarValue::Utf8View(Some(s)) => ScalarValue::Utf8View(Some(
133+
compute_repeat(s, count, i32::MAX as usize)?,
134+
)),
135+
ScalarValue::Utf8(Some(s)) => ScalarValue::Utf8(Some(
136+
compute_repeat(s, count, i32::MAX as usize)?,
137+
)),
136138
ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some(
137139
compute_repeat(s, count, i64::MAX as usize)?,
138140
)),
@@ -183,26 +185,47 @@ fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
183185
match string_array.data_type() {
184186
Utf8View => {
185187
let string_view_array = string_array.as_string_view();
186-
repeat_impl::<i32, &StringViewArray>(
188+
let (_, max_item_capacity) = calculate_capacities(
187189
&string_view_array,
188190
number_array,
189191
i32::MAX as usize,
192+
)?;
193+
let builder = StringViewBuilder::with_capacity(string_array.len());
194+
repeat_impl::<&StringViewArray, StringViewBuilder>(
195+
&string_view_array,
196+
number_array,
197+
max_item_capacity,
198+
builder,
190199
)
191200
}
192201
Utf8 => {
193202
let string_arr = string_array.as_string::<i32>();
194-
repeat_impl::<i32, &GenericStringArray<i32>>(
203+
let (total_capacity, max_item_capacity) =
204+
calculate_capacities(&string_arr, number_array, i32::MAX as usize)?;
205+
let builder = GenericStringBuilder::<i32>::with_capacity(
206+
string_array.len(),
207+
total_capacity,
208+
);
209+
repeat_impl::<&GenericStringArray<i32>, GenericStringBuilder<i32>>(
195210
&string_arr,
196211
number_array,
197-
i32::MAX as usize,
212+
max_item_capacity,
213+
builder,
198214
)
199215
}
200216
LargeUtf8 => {
201217
let string_arr = string_array.as_string::<i64>();
202-
repeat_impl::<i64, &GenericStringArray<i64>>(
218+
let (total_capacity, max_item_capacity) =
219+
calculate_capacities(&string_arr, number_array, i64::MAX as usize)?;
220+
let builder = GenericStringBuilder::<i64>::with_capacity(
221+
string_array.len(),
222+
total_capacity,
223+
);
224+
repeat_impl::<&GenericStringArray<i64>, GenericStringBuilder<i64>>(
203225
&string_arr,
204226
number_array,
205-
i64::MAX as usize,
227+
max_item_capacity,
228+
builder,
206229
)
207230
}
208231
other => exec_err!(
@@ -212,17 +235,17 @@ fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
212235
}
213236
}
214237

215-
fn repeat_impl<'a, T, S>(
238+
fn calculate_capacities<'a, S>(
216239
string_array: &S,
217240
number_array: &Int64Array,
218241
max_str_len: usize,
219-
) -> Result<ArrayRef>
242+
) -> Result<(usize, usize)>
220243
where
221-
T: OffsetSizeTrait,
222-
S: StringArrayType<'a> + 'a,
244+
S: StringArrayType<'a>,
223245
{
224246
let mut total_capacity = 0;
225247
let mut max_item_capacity = 0;
248+
226249
string_array.iter().zip(number_array.iter()).try_for_each(
227250
|(string, number)| -> Result<(), DataFusionError> {
228251
match (string, number) {
@@ -244,9 +267,19 @@ where
244267
},
245268
)?;
246269

247-
let mut builder =
248-
GenericStringBuilder::<T>::with_capacity(string_array.len(), total_capacity);
270+
Ok((total_capacity, max_item_capacity))
271+
}
249272

273+
fn repeat_impl<'a, S, B>(
274+
string_array: &S,
275+
number_array: &Int64Array,
276+
max_item_capacity: usize,
277+
mut builder: B,
278+
) -> Result<ArrayRef>
279+
where
280+
S: StringArrayType<'a> + 'a,
281+
B: StringLikeArrayBuilder,
282+
{
250283
// Reusable buffer to avoid allocations in string.repeat()
251284
let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);
252285

@@ -303,8 +336,8 @@ where
303336

304337
#[cfg(test)]
305338
mod tests {
306-
use arrow::array::{Array, StringArray};
307-
use arrow::datatypes::DataType::Utf8;
339+
use arrow::array::{Array, LargeStringArray, StringArray, StringViewArray};
340+
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
308341

309342
use datafusion_common::ScalarValue;
310343
use datafusion_common::{Result, exec_err};
@@ -357,8 +390,8 @@ mod tests {
357390
],
358391
Ok(Some("PgPgPgPg")),
359392
&str,
360-
Utf8,
361-
StringArray
393+
Utf8View,
394+
StringViewArray
362395
);
363396
test_function!(
364397
RepeatFunc::new(),
@@ -368,8 +401,19 @@ mod tests {
368401
],
369402
Ok(None),
370403
&str,
371-
Utf8,
372-
StringArray
404+
Utf8View,
405+
StringViewArray
406+
);
407+
test_function!(
408+
RepeatFunc::new(),
409+
vec![
410+
ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("Pg")))),
411+
ColumnarValue::Scalar(ScalarValue::Int64(None)),
412+
],
413+
Ok(None),
414+
&str,
415+
LargeUtf8,
416+
LargeStringArray
373417
);
374418
test_function!(
375419
RepeatFunc::new(),
@@ -379,8 +423,8 @@ mod tests {
379423
],
380424
Ok(None),
381425
&str,
382-
Utf8,
383-
StringArray
426+
Utf8View,
427+
StringViewArray
384428
);
385429
test_function!(
386430
RepeatFunc::new(),

datafusion/sqllogictest/test_files/string/string_literal.slt

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,11 +347,35 @@ SELECT repeat('foo', 3)
347347
----
348348
foofoofoo
349349

350+
query T
351+
SELECT repeat(arrow_cast('foo', 'LargeUtf8'), 3)
352+
----
353+
foofoofoo
354+
355+
query T
356+
SELECT repeat(arrow_cast('foo', 'Utf8View'), 3)
357+
----
358+
foofoofoo
359+
350360
query T
351361
SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3)
352362
----
353363
foofoofoo
354364

365+
query T
366+
SELECT arrow_typeof(repeat('foo', 3))
367+
----
368+
Utf8
369+
370+
query T
371+
SELECT arrow_typeof(repeat(arrow_cast('foo', 'LargeUtf8'), 3))
372+
----
373+
LargeUtf8
374+
375+
query T
376+
SELECT arrow_typeof(repeat(arrow_cast('foo', 'Utf8View'), 3))
377+
----
378+
Utf8View
355379

356380
query T
357381
SELECT replace('foobar', 'bar', 'hello')

0 commit comments

Comments
 (0)