Skip to content

Commit b51edff

Browse files
authored
Update reverse UDF to emit utf8view when input is utf8view (#20604)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> Part of #20585 ## Rationale for this change Functions should ideally emit strings in the same format as their input; previously, the `reverse` function emitted Utf8 output even when the input was Utf8View. ## What changes are included in this PR? Code, tests. ## Are these changes tested? Yes. ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 33b9afa commit b51edff

2 files changed

Lines changed: 56 additions & 23 deletions

File tree

datafusion/functions/src/unicode/reverse.rs

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21-
use crate::utils::{make_scalar_function, utf8_to_str_type};
21+
use crate::utils::make_scalar_function;
2222
use DataType::{LargeUtf8, Utf8, Utf8View};
2323
use arrow::array::{
24-
Array, ArrayRef, AsArray, GenericStringBuilder, OffsetSizeTrait, StringArrayType,
24+
Array, ArrayRef, AsArray, LargeStringBuilder, StringArrayType, StringBuilder,
25+
StringLikeArrayBuilder, StringViewBuilder,
2526
};
2627
use arrow::datatypes::DataType;
2728
use datafusion_common::{Result, exec_err};
@@ -82,7 +83,7 @@ impl ScalarUDFImpl for ReverseFunc {
8283
}
8384

8485
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
85-
utf8_to_str_type(&arg_types[0], "reverse")
86+
Ok(arg_types[0].clone())
8687
}
8788

8889
fn invoke_with_args(
@@ -91,8 +92,7 @@ impl ScalarUDFImpl for ReverseFunc {
9192
) -> Result<ColumnarValue> {
9293
let args = &args.args;
9394
match args[0].data_type() {
94-
Utf8 | Utf8View => make_scalar_function(reverse::<i32>, vec![])(args),
95-
LargeUtf8 => make_scalar_function(reverse::<i64>, vec![])(args),
95+
Utf8 | Utf8View | LargeUtf8 => make_scalar_function(reverse, vec![])(args),
9696
other => {
9797
exec_err!("Unsupported data type {other:?} for function reverse")
9898
}
@@ -106,21 +106,39 @@ impl ScalarUDFImpl for ReverseFunc {
106106

107107
/// Reverses the order of the characters in the string `reverse('abcde') = 'edcba'`.
108108
/// The implementation uses UTF-8 code points as characters
109-
fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
110-
if args[0].data_type() == &Utf8View {
111-
reverse_impl::<T, _>(&args[0].as_string_view())
112-
} else {
113-
reverse_impl::<T, _>(&args[0].as_string::<T>())
109+
fn reverse(args: &[ArrayRef]) -> Result<ArrayRef> {
110+
let len = args[0].len();
111+
112+
match args[0].data_type() {
113+
Utf8 => reverse_impl(
114+
&args[0].as_string::<i32>(),
115+
StringBuilder::with_capacity(len, 1024),
116+
),
117+
Utf8View => reverse_impl(
118+
&args[0].as_string_view(),
119+
StringViewBuilder::with_capacity(len),
120+
),
121+
LargeUtf8 => reverse_impl(
122+
&args[0].as_string::<i64>(),
123+
LargeStringBuilder::with_capacity(len, 1024),
124+
),
125+
_ => unreachable!(
126+
"Reverse can only be applied to Utf8View, Utf8 and LargeUtf8 types"
127+
),
114128
}
115129
}
116130

117-
fn reverse_impl<'a, T: OffsetSizeTrait, V: StringArrayType<'a>>(
118-
string_array: &V,
119-
) -> Result<ArrayRef> {
120-
let mut builder = GenericStringBuilder::<T>::with_capacity(string_array.len(), 1024);
121-
131+
fn reverse_impl<'a, StringArrType, StringBuilderType>(
132+
string_array: &StringArrType,
133+
mut array_builder: StringBuilderType,
134+
) -> Result<ArrayRef>
135+
where
136+
StringArrType: StringArrayType<'a>,
137+
StringBuilderType: StringLikeArrayBuilder,
138+
{
122139
let mut string_buf = String::new();
123140
let mut byte_buf = Vec::<u8>::new();
141+
124142
for string in string_array.iter() {
125143
if let Some(s) = string {
126144
if s.is_ascii() {
@@ -129,25 +147,25 @@ fn reverse_impl<'a, T: OffsetSizeTrait, V: StringArrayType<'a>>(
129147
byte_buf.reverse();
130148
// SAFETY: Since the original string was ASCII, reversing the bytes still results in valid UTF-8.
131149
let reversed = unsafe { std::str::from_utf8_unchecked(&byte_buf) };
132-
builder.append_value(reversed);
150+
array_builder.append_value(reversed);
133151
byte_buf.clear();
134152
} else {
135153
string_buf.extend(s.chars().rev());
136-
builder.append_value(&string_buf);
154+
array_builder.append_value(&string_buf);
137155
string_buf.clear();
138156
}
139157
} else {
140-
builder.append_null();
158+
array_builder.append_null();
141159
}
142160
}
143161

144-
Ok(Arc::new(builder.finish()) as ArrayRef)
162+
Ok(Arc::new(array_builder.finish()) as ArrayRef)
145163
}
146164

147165
#[cfg(test)]
148166
mod tests {
149-
use arrow::array::{Array, LargeStringArray, StringArray};
150-
use arrow::datatypes::DataType::{LargeUtf8, Utf8};
167+
use arrow::array::{Array, LargeStringArray, StringArray, StringViewArray};
168+
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
151169

152170
use datafusion_common::{Result, ScalarValue};
153171
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
@@ -180,8 +198,8 @@ mod tests {
180198
vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
181199
$EXPECTED,
182200
&str,
183-
Utf8,
184-
StringArray
201+
Utf8View,
202+
StringViewArray
185203
);
186204
};
187205
}

datafusion/sqllogictest/test_files/string/string_literal.slt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,21 @@ SELECT reverse(arrow_cast('abcde', 'Utf8View'))
389389
----
390390
edcba
391391

392+
query T
393+
SELECT arrow_typeof(reverse('abcde'))
394+
----
395+
Utf8
396+
397+
query T
398+
SELECT arrow_typeof(reverse(arrow_cast('abcde', 'LargeUtf8')))
399+
----
400+
LargeUtf8
401+
402+
query T
403+
SELECT arrow_typeof(reverse(arrow_cast('abcde', 'Utf8View')))
404+
----
405+
Utf8View
406+
392407
query T
393408
SELECT reverse(arrow_cast('abcde', 'Dictionary(Int32, Utf8)'))
394409
----

0 commit comments

Comments
 (0)