Skip to content

Commit d8c9797

Browse files
kazantsev-maksimKazantsev Maksim
andauthored
Spark is_valid_utf8 function implementation (#21627)
## Which issue does this PR close? N/A ## Rationale for this change Add a new Spark function: https://spark.apache.org/docs/latest/api/sql/index.html#is_valid_utf8 ## What changes are included in this PR? - Implementation - SLT tests ## Are these changes tested? Yes, tests added as part of this PR. ## Are there any user-facing changes? No, this is a new function. --------- Co-authored-by: Kazantsev Maksim <mn.kazantsev@gmail.com>
1 parent 6aa5a7e commit d8c9797

3 files changed

Lines changed: 329 additions & 0 deletions

File tree

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::datatypes::{DataType, Field, FieldRef};
19+
use datafusion::logical_expr::{ColumnarValue, Signature, Volatility};
20+
use datafusion_common::{Result, internal_err};
21+
use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
22+
23+
use arrow::array::{Array, ArrayRef, BooleanArray};
24+
use arrow::buffer::BooleanBuffer;
25+
use datafusion_common::cast::{
26+
as_binary_array, as_binary_view_array, as_large_binary_array,
27+
};
28+
use datafusion_common::utils::take_function_args;
29+
use datafusion_functions::utils::make_scalar_function;
30+
31+
use std::sync::Arc;
32+
33+
/// Spark-compatible `is_valid_utf8` expression
34+
/// <https://spark.apache.org/docs/latest/api/sql/index.html#is_valid_utf8>
35+
#[derive(Debug, PartialEq, Eq, Hash)]
36+
pub struct SparkIsValidUtf8 {
37+
signature: Signature,
38+
}
39+
40+
impl Default for SparkIsValidUtf8 {
41+
fn default() -> Self {
42+
Self::new()
43+
}
44+
}
45+
46+
impl SparkIsValidUtf8 {
47+
pub fn new() -> Self {
48+
Self {
49+
signature: Signature::uniform(
50+
1,
51+
vec![
52+
DataType::Utf8,
53+
DataType::LargeUtf8,
54+
DataType::Utf8View,
55+
DataType::Binary,
56+
DataType::BinaryView,
57+
DataType::LargeBinary,
58+
],
59+
Volatility::Immutable,
60+
),
61+
}
62+
}
63+
}
64+
65+
impl ScalarUDFImpl for SparkIsValidUtf8 {
66+
fn name(&self) -> &str {
67+
"is_valid_utf8"
68+
}
69+
70+
fn signature(&self) -> &Signature {
71+
&self.signature
72+
}
73+
74+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
75+
internal_err!("return_field_from_args should be used instead")
76+
}
77+
78+
fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> {
79+
Ok(Arc::new(Field::new(self.name(), DataType::Boolean, true)))
80+
}
81+
82+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
83+
make_scalar_function(spark_is_valid_utf8_inner, vec![])(&args.args)
84+
}
85+
}
86+
87+
fn spark_is_valid_utf8_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
88+
let [array] = take_function_args("is_valid_utf8", args)?;
89+
match array.data_type() {
90+
DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => {
91+
Ok(Arc::new(BooleanArray::new(
92+
BooleanBuffer::new_set(array.len()),
93+
array.nulls().cloned(),
94+
)))
95+
}
96+
DataType::Binary => Ok(Arc::new(
97+
as_binary_array(array)?
98+
.iter()
99+
.map(|x| x.map(|y| str::from_utf8(y).is_ok()))
100+
.collect::<BooleanArray>(),
101+
)),
102+
DataType::LargeBinary => Ok(Arc::new(
103+
as_large_binary_array(array)?
104+
.iter()
105+
.map(|x| x.map(|y| str::from_utf8(y).is_ok()))
106+
.collect::<BooleanArray>(),
107+
)),
108+
DataType::BinaryView => Ok(Arc::new(
109+
as_binary_view_array(array)?
110+
.iter()
111+
.map(|x| x.map(|y| str::from_utf8(y).is_ok()))
112+
.collect::<BooleanArray>(),
113+
)),
114+
data_type => {
115+
internal_err!("is_valid_utf8 does not support: {data_type}")
116+
}
117+
}
118+
}

datafusion/spark/src/function/string/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub mod concat;
2222
pub mod elt;
2323
pub mod format_string;
2424
pub mod ilike;
25+
pub mod is_valid_utf8;
2526
pub mod length;
2627
pub mod like;
2728
pub mod luhn_check;
@@ -49,6 +50,7 @@ make_udf_function!(substring::SparkSubstring, substring);
4950
make_udf_function!(base64::SparkUnBase64, unbase64);
5051
make_udf_function!(soundex::SparkSoundex, soundex);
5152
make_udf_function!(make_valid_utf8::SparkMakeValidUtf8, make_valid_utf8);
53+
make_udf_function!(is_valid_utf8::SparkIsValidUtf8, is_valid_utf8);
5254

5355
pub mod expr_fn {
5456
use datafusion_functions::export_functions;
@@ -115,6 +117,11 @@ pub mod expr_fn {
115117
str
116118
));
117119
export_functions!((soundex, "Returns Soundex code of the string.", str));
120+
export_functions!((
121+
is_valid_utf8,
122+
"Returns true if str is a valid UTF-8 string, otherwise returns false",
123+
str
124+
));
118125
export_functions!((
119126
make_valid_utf8,
120127
"Returns the original string if str is a valid UTF-8 string, otherwise returns a new string whose invalid UTF8 byte sequences are replaced using the UNICODE replacement character U+FFFD.",
@@ -139,5 +146,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
139146
unbase64(),
140147
soundex(),
141148
make_valid_utf8(),
149+
is_valid_utf8(),
142150
]
143151
}
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
statement ok
19+
CREATE TABLE test_is_valid_utf8(value STRING) AS VALUES
20+
(arrow_cast('Hello, world!', 'Utf8')),
21+
(arrow_cast('Spark', 'Utf8')),
22+
(arrow_cast('DataFusion', 'Utf8')),
23+
(arrow_cast('ASCII only 123 !@#', 'Utf8')),
24+
(arrow_cast(NULL, 'Utf8'));
25+
26+
query B
27+
SELECT is_valid_utf8(value) FROM test_is_valid_utf8;
28+
----
29+
true
30+
true
31+
true
32+
true
33+
NULL
34+
35+
query B
36+
SELECT is_valid_utf8(NULL::string);
37+
----
38+
NULL
39+
40+
query B
41+
SELECT is_valid_utf8('Hello, world!'::string);
42+
----
43+
true
44+
45+
query B
46+
SELECT is_valid_utf8('😀🎉✨'::string);
47+
----
48+
true
49+
50+
query B
51+
SELECT is_valid_utf8(''::string);
52+
----
53+
true
54+
55+
query B
56+
SELECT is_valid_utf8('ASCII only 123 !@#'::string);
57+
----
58+
true
59+
60+
query B
61+
SELECT is_valid_utf8(arrow_cast(x'C2A9', 'Binary'));
62+
----
63+
true
64+
65+
query B
66+
SELECT is_valid_utf8(arrow_cast(x'C2AE', 'Binary'));
67+
----
68+
true
69+
70+
query B
71+
SELECT is_valid_utf8(arrow_cast(x'E282AC', 'Binary'));
72+
----
73+
true
74+
75+
query B
76+
SELECT is_valid_utf8(arrow_cast(x'E284A2', 'Binary'));
77+
----
78+
true
79+
80+
query B
81+
SELECT is_valid_utf8(arrow_cast(x'F09F9880', 'Binary'));
82+
----
83+
true
84+
85+
query B
86+
SELECT is_valid_utf8(arrow_cast(x'F09F8E89', 'Binary'));
87+
----
88+
true
89+
90+
query B
91+
SELECT is_valid_utf8(arrow_cast(x'80', 'Binary'));
92+
----
93+
false
94+
95+
query B
96+
SELECT is_valid_utf8(arrow_cast(x'BF', 'Binary'));
97+
----
98+
false
99+
100+
query B
101+
SELECT is_valid_utf8(arrow_cast(x'808080', 'Binary'));
102+
----
103+
false
104+
105+
query B
106+
SELECT is_valid_utf8(arrow_cast(x'C2', 'Binary'));
107+
----
108+
false
109+
110+
query B
111+
SELECT is_valid_utf8(arrow_cast(x'E2', 'Binary'));
112+
----
113+
false
114+
115+
query B
116+
SELECT is_valid_utf8(arrow_cast(x'F0', 'Binary'));
117+
----
118+
false
119+
120+
query B
121+
SELECT is_valid_utf8(arrow_cast(x'E282', 'Binary'));
122+
----
123+
false
124+
125+
query B
126+
SELECT is_valid_utf8(arrow_cast(x'C081', 'Binary'));
127+
----
128+
false
129+
130+
query B
131+
SELECT is_valid_utf8(arrow_cast(x'E08080', 'Binary'));
132+
----
133+
false
134+
135+
query B
136+
SELECT is_valid_utf8(arrow_cast(x'F0808080', 'Binary'));
137+
----
138+
false
139+
140+
query B
141+
SELECT is_valid_utf8(arrow_cast(x'FE', 'Binary'));
142+
----
143+
false
144+
145+
query B
146+
SELECT is_valid_utf8(arrow_cast(x'FF', 'Binary'));
147+
----
148+
false
149+
150+
query B
151+
SELECT is_valid_utf8(arrow_cast(x'61C262', 'Binary'));
152+
----
153+
false
154+
155+
query B
156+
SELECT is_valid_utf8(arrow_cast(x'41BF42', 'Binary'));
157+
----
158+
false
159+
160+
query B
161+
SELECT is_valid_utf8(arrow_cast(x'ED9FBF', 'Binary'));
162+
----
163+
true
164+
165+
query B
166+
SELECT is_valid_utf8(arrow_cast(x'EDA080', 'Binary'));
167+
----
168+
false
169+
170+
query B
171+
SELECT is_valid_utf8(arrow_cast(x'EDBFBF', 'Binary'));
172+
----
173+
false
174+
175+
query B
176+
SELECT is_valid_utf8(arrow_cast(x'F48FBFBF', 'Binary'));
177+
----
178+
true
179+
180+
query B
181+
SELECT is_valid_utf8(arrow_cast(x'F4908080', 'Binary'));
182+
----
183+
false
184+
185+
query B
186+
SELECT is_valid_utf8(arrow_cast(x'6162C2A963', 'Binary'));
187+
----
188+
true
189+
190+
query B
191+
SELECT is_valid_utf8(arrow_cast(x'6162806364', 'Binary'));
192+
----
193+
false
194+
195+
query B
196+
SELECT is_valid_utf8(arrow_cast(x'610062', 'Binary'));
197+
----
198+
true
199+
200+
query B
201+
SELECT is_valid_utf8(arrow_cast(x'', 'Binary'));
202+
----
203+
true

0 commit comments

Comments
 (0)