Skip to content

Commit af79d14

Browse files
authored
Make translate emit Utf8View for Utf8View input (#20624)
## Which issue does this PR close? - Part of #20585 <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> ## Rationale for this change String UDFs should preserve string representation where feasible. `translate` previously accepted Utf8View input but emitted Utf8, causing an unnecessary type downgrade. This aligns `translate` with the expected behavior of returning the same string type as its primary input. <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? 1. Updated `translate` return type inference to emit Utf8View when input is Utf8View, while preserving existing behavior for Utf8 and LargeUtf8. 2. Refactored `translate` and `translate_with_map` to use explicit string builders (via a local `TranslateOutput` helper trait) instead of `.collect::<GenericStringArray<T>>()`, so the correct output array type is produced for each input type. 3. Added unit tests for Utf8View input (basic, null, non-ASCII) and sqllogictests verifying `arrow_typeof` output for all three string types. <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? Yes. Unit tests and sqllogictests are included. <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? No. <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 39226c3 commit af79d14

2 files changed

Lines changed: 126 additions & 46 deletions

File tree

datafusion/functions/src/unicode/translate.rs

Lines changed: 104 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@
1616
// under the License.
1717

1818
use std::any::Any;
19-
use std::sync::Arc;
2019

2120
use arrow::array::{
22-
ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait,
21+
ArrayAccessor, ArrayIter, ArrayRef, AsArray, LargeStringBuilder, StringBuilder,
22+
StringLikeArrayBuilder, StringViewBuilder,
2323
};
2424
use arrow::datatypes::DataType;
2525
use datafusion_common::HashMap;
2626
use unicode_segmentation::UnicodeSegmentation;
2727

28-
use crate::utils::{make_scalar_function, utf8_to_str_type};
28+
use crate::utils::make_scalar_function;
2929
use datafusion_common::{Result, exec_err};
3030
use datafusion_expr::TypeSignature::Exact;
3131
use datafusion_expr::{
@@ -93,7 +93,7 @@ impl ScalarUDFImpl for TranslateFunc {
9393
}
9494

9595
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
96-
utf8_to_str_type(&arg_types[0], "translate")
96+
Ok(arg_types[0].clone())
9797
}
9898

9999
fn invoke_with_args(
@@ -116,33 +116,42 @@ impl ScalarUDFImpl for TranslateFunc {
116116
let ascii_table = build_ascii_translate_table(from_str, to_str);
117117

118118
let string_array = args.args[0].to_array_of_size(args.number_rows)?;
119+
let len = string_array.len();
119120

120121
let result = match string_array.data_type() {
121122
DataType::Utf8View => {
122123
let arr = string_array.as_string_view();
123-
translate_with_map::<i32, _>(
124+
let builder = StringViewBuilder::with_capacity(len);
125+
translate_with_map(
124126
arr,
125127
&from_map,
126128
&to_graphemes,
127129
ascii_table.as_ref(),
130+
builder,
128131
)
129132
}
130133
DataType::Utf8 => {
131134
let arr = string_array.as_string::<i32>();
132-
translate_with_map::<i32, _>(
135+
let builder =
136+
StringBuilder::with_capacity(len, arr.value_data().len());
137+
translate_with_map(
133138
arr,
134139
&from_map,
135140
&to_graphemes,
136141
ascii_table.as_ref(),
142+
builder,
137143
)
138144
}
139145
DataType::LargeUtf8 => {
140146
let arr = string_array.as_string::<i64>();
141-
translate_with_map::<i64, _>(
147+
let builder =
148+
LargeStringBuilder::with_capacity(len, arr.value_data().len());
149+
translate_with_map(
142150
arr,
143151
&from_map,
144152
&to_graphemes,
145153
ascii_table.as_ref(),
154+
builder,
146155
)
147156
}
148157
other => {
@@ -172,24 +181,30 @@ fn try_as_scalar_str(cv: &ColumnarValue) -> Option<&str> {
172181
}
173182

174183
fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {
184+
let len = args[0].len();
175185
match args[0].data_type() {
176186
DataType::Utf8View => {
177187
let string_array = args[0].as_string_view();
178188
let from_array = args[1].as_string::<i32>();
179189
let to_array = args[2].as_string::<i32>();
180-
translate::<i32, _, _>(string_array, from_array, to_array)
190+
let builder = StringViewBuilder::with_capacity(len);
191+
translate(string_array, from_array, to_array, builder)
181192
}
182193
DataType::Utf8 => {
183194
let string_array = args[0].as_string::<i32>();
184195
let from_array = args[1].as_string::<i32>();
185196
let to_array = args[2].as_string::<i32>();
186-
translate::<i32, _, _>(string_array, from_array, to_array)
197+
let builder =
198+
StringBuilder::with_capacity(len, string_array.value_data().len());
199+
translate(string_array, from_array, to_array, builder)
187200
}
188201
DataType::LargeUtf8 => {
189202
let string_array = args[0].as_string::<i64>();
190203
let from_array = args[1].as_string::<i32>();
191204
let to_array = args[2].as_string::<i32>();
192-
translate::<i64, _, _>(string_array, from_array, to_array)
205+
let builder =
206+
LargeStringBuilder::with_capacity(len, string_array.value_data().len());
207+
translate(string_array, from_array, to_array, builder)
193208
}
194209
other => {
195210
exec_err!("Unsupported data type {other:?} for function translate")
@@ -199,14 +214,16 @@ fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {
199214

200215
/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.
201216
/// translate('12345', '143', 'ax') = 'a2x5'
202-
fn translate<'a, T: OffsetSizeTrait, V, B>(
217+
fn translate<'a, V, B, O>(
203218
string_array: V,
204219
from_array: B,
205220
to_array: B,
221+
mut builder: O,
206222
) -> Result<ArrayRef>
207223
where
208224
V: ArrayAccessor<Item = &'a str>,
209225
B: ArrayAccessor<Item = &'a str>,
226+
O: StringLikeArrayBuilder,
210227
{
211228
let string_array_iter = ArrayIter::new(string_array);
212229
let from_array_iter = ArrayIter::new(from_array);
@@ -219,10 +236,9 @@ where
219236
let mut string_graphemes: Vec<&str> = Vec::new();
220237
let mut result_graphemes: Vec<&str> = Vec::new();
221238

222-
let result = string_array_iter
223-
.zip(from_array_iter)
224-
.zip(to_array_iter)
225-
.map(|((string, from), to)| match (string, from, to) {
239+
for ((string, from), to) in string_array_iter.zip(from_array_iter).zip(to_array_iter)
240+
{
241+
match (string, from, to) {
226242
(Some(string), Some(from), Some(to)) => {
227243
// Clear and reuse buffers
228244
from_map.clear();
@@ -254,13 +270,13 @@ where
254270
}
255271
}
256272

257-
Some(result_graphemes.concat())
273+
builder.append_value(&result_graphemes.concat());
258274
}
259-
_ => None,
260-
})
261-
.collect::<GenericStringArray<T>>();
275+
_ => builder.append_null(),
276+
}
277+
}
262278

263-
Ok(Arc::new(result) as ArrayRef)
279+
Ok(builder.finish())
264280
}
265281

266282
/// Sentinel value in the ASCII translate table indicating the character should
@@ -300,21 +316,23 @@ fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> {
300316
/// translation map instead of rebuilding it for every row. When an ASCII byte
301317
/// lookup table is provided, ASCII input rows use the lookup table; non-ASCII
302318
/// inputs fallback to using the map.
303-
fn translate_with_map<'a, T: OffsetSizeTrait, V>(
319+
fn translate_with_map<'a, V, O>(
304320
string_array: V,
305321
from_map: &HashMap<&str, usize>,
306322
to_graphemes: &[&str],
307323
ascii_table: Option<&[u8; 128]>,
324+
mut builder: O,
308325
) -> Result<ArrayRef>
309326
where
310327
V: ArrayAccessor<Item = &'a str>,
328+
O: StringLikeArrayBuilder,
311329
{
312330
let mut result_graphemes: Vec<&str> = Vec::new();
313331
let mut ascii_buf: Vec<u8> = Vec::new();
314332

315-
let result = ArrayIter::new(string_array)
316-
.map(|string| {
317-
string.map(|s| {
333+
for string in ArrayIter::new(string_array) {
334+
match string {
335+
Some(s) => {
318336
// Fast path: byte-level table lookup for ASCII strings
319337
if let Some(table) = ascii_table
320338
&& s.is_ascii()
@@ -327,37 +345,38 @@ where
327345
}
328346
}
329347
// SAFETY: all bytes are ASCII, hence valid UTF-8.
330-
return unsafe {
331-
std::str::from_utf8_unchecked(&ascii_buf).to_owned()
332-
};
333-
}
334-
335-
// Slow path: grapheme-based translation
336-
result_graphemes.clear();
337-
338-
for c in s.graphemes(true) {
339-
match from_map.get(c) {
340-
Some(n) => {
341-
if let Some(replacement) = to_graphemes.get(*n) {
342-
result_graphemes.push(*replacement);
348+
builder.append_value(unsafe {
349+
std::str::from_utf8_unchecked(&ascii_buf)
350+
});
351+
} else {
352+
// Slow path: grapheme-based translation
353+
result_graphemes.clear();
354+
355+
for c in s.graphemes(true) {
356+
match from_map.get(c) {
357+
Some(n) => {
358+
if let Some(replacement) = to_graphemes.get(*n) {
359+
result_graphemes.push(*replacement);
360+
}
343361
}
362+
None => result_graphemes.push(c),
344363
}
345-
None => result_graphemes.push(c),
346364
}
347-
}
348365

349-
result_graphemes.concat()
350-
})
351-
})
352-
.collect::<GenericStringArray<T>>();
366+
builder.append_value(&result_graphemes.concat());
367+
}
368+
}
369+
None => builder.append_null(),
370+
}
371+
}
353372

354-
Ok(Arc::new(result) as ArrayRef)
373+
Ok(builder.finish())
355374
}
356375

357376
#[cfg(test)]
358377
mod tests {
359-
use arrow::array::{Array, StringArray};
360-
use arrow::datatypes::DataType::Utf8;
378+
use arrow::array::{Array, StringArray, StringViewArray};
379+
use arrow::datatypes::DataType::{Utf8, Utf8View};
361380

362381
use datafusion_common::{Result, ScalarValue};
363382
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
@@ -453,6 +472,45 @@ mod tests {
453472
Utf8,
454473
StringArray
455474
);
475+
// Utf8View input should produce Utf8View output
476+
test_function!(
477+
TranslateFunc::new(),
478+
vec![
479+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("12345".into()))),
480+
ColumnarValue::Scalar(ScalarValue::from("143")),
481+
ColumnarValue::Scalar(ScalarValue::from("ax"))
482+
],
483+
Ok(Some("a2x5")),
484+
&str,
485+
Utf8View,
486+
StringViewArray
487+
);
488+
// Null Utf8View input
489+
test_function!(
490+
TranslateFunc::new(),
491+
vec![
492+
ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
493+
ColumnarValue::Scalar(ScalarValue::from("143")),
494+
ColumnarValue::Scalar(ScalarValue::from("ax"))
495+
],
496+
Ok(None),
497+
&str,
498+
Utf8View,
499+
StringViewArray
500+
);
501+
// Non-ASCII Utf8View input
502+
test_function!(
503+
TranslateFunc::new(),
504+
vec![
505+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("é2íñ5".into()))),
506+
ColumnarValue::Scalar(ScalarValue::from("éñí")),
507+
ColumnarValue::Scalar(ScalarValue::from("óü"))
508+
],
509+
Ok(Some("ó2ü5")),
510+
&str,
511+
Utf8View,
512+
StringViewArray
513+
);
456514

457515
#[cfg(not(feature = "unicode_expressions"))]
458516
test_function!(

datafusion/sqllogictest/test_files/string/string_literal.slt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1783,3 +1783,25 @@ SELECT
17831783
;
17841784
----
17851785
48 176 32 40
1786+
1787+
# translate preserves input string type
1788+
1789+
query T
1790+
SELECT translate(arrow_cast('12345', 'Utf8View'), '143', 'ax')
1791+
----
1792+
a2x5
1793+
1794+
query T
1795+
SELECT arrow_typeof(translate('12345', '143', 'ax'))
1796+
----
1797+
Utf8
1798+
1799+
query T
1800+
SELECT arrow_typeof(translate(arrow_cast('12345', 'LargeUtf8'), '143', 'ax'))
1801+
----
1802+
LargeUtf8
1803+
1804+
query T
1805+
SELECT arrow_typeof(translate(arrow_cast('12345', 'Utf8View'), '143', 'ax'))
1806+
----
1807+
Utf8View

0 commit comments

Comments
 (0)