diff --git a/datafusion/spark/src/function/string/make_valid_utf8.rs b/datafusion/spark/src/function/string/make_valid_utf8.rs new file mode 100644 index 0000000000000..08487f8411643 --- /dev/null +++ b/datafusion/spark/src/function/string/make_valid_utf8.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, LargeStringArray, StringArray}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion::logical_expr::{ColumnarValue, Signature, Volatility}; +use datafusion_common::cast::{ + as_binary_array, as_binary_view_array, as_large_binary_array, +}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions::utils::make_scalar_function; +use std::sync::Arc; + +/// Spark-compatible `make_valid_utf8` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkMakeValidUtf8 { + signature: Signature, +} + +impl Default for SparkMakeValidUtf8 { + fn default() -> Self { + Self::new() + } +} + +impl SparkMakeValidUtf8 { + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![ + DataType::Utf8, + DataType::LargeUtf8, + DataType::Utf8View, + DataType::Binary, + DataType::BinaryView, + DataType::LargeBinary, + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkMakeValidUtf8 { + fn name(&self) -> &str { + "make_valid_utf8" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let [make_valid_utf8] = take_function_args(self.name(), args.arg_fields)?; + let return_type = match make_valid_utf8.data_type() { + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + Ok(make_valid_utf8.data_type().clone()) + } + DataType::Binary | DataType::BinaryView => Ok(DataType::Utf8), + DataType::LargeBinary => Ok(DataType::LargeUtf8), + data_type => internal_err!("make_valid_utf8 does not support: {data_type}"), + }?; + Ok(Arc::new(Field::new( + self.name(), + return_type, + make_valid_utf8.is_nullable(), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_make_valid_utf8_inner, vec![])(&args.args) + } +} + +fn spark_make_valid_utf8_inner(args: &[ArrayRef]) -> Result { + let array = &args[0]; + match &array.data_type() { + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Ok(array.to_owned()), + DataType::Binary => Ok(Arc::new( + as_binary_array(&array)? + .iter() + .map(|x| x.map(String::from_utf8_lossy)) + .collect::(), + )), + DataType::BinaryView => Ok(Arc::new( + as_binary_view_array(&array)? + .iter() + .map(|x| x.map(String::from_utf8_lossy)) + .collect::(), + )), + DataType::LargeBinary => Ok(Arc::new( + as_large_binary_array(&array)? + .iter() + .map(|x| x.map(String::from_utf8_lossy)) + .collect::(), + )), + data_type => { + internal_err!("make_valid_utf8 does not support: {data_type}") + } + } +} diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs index 7bcdac5d85474..e0f6878fdea7b 100644 --- a/datafusion/spark/src/function/string/mod.rs +++ b/datafusion/spark/src/function/string/mod.rs @@ -25,6 +25,7 @@ pub mod ilike; pub mod length; pub mod like; pub mod luhn_check; +pub mod make_valid_utf8; pub mod soundex; pub mod space; pub mod substring; @@ -47,6 +48,7 @@ make_udf_function!(space::SparkSpace, space); make_udf_function!(substring::SparkSubstring, substring); make_udf_function!(base64::SparkUnBase64, unbase64); make_udf_function!(soundex::SparkSoundex, soundex); +make_udf_function!(make_valid_utf8::SparkMakeValidUtf8, make_valid_utf8); pub mod expr_fn { use datafusion_functions::export_functions; @@ -113,6 +115,11 @@ pub mod expr_fn { str )); export_functions!((soundex, "Returns Soundex code of the string.", str)); + export_functions!(( + make_valid_utf8, + "Returns the original string if str is a valid UTF-8 string, otherwise returns a new string whose invalid UTF8 byte sequences are replaced using the UNICODE replacement character U+FFFD.", + str + )); } pub fn functions() -> Vec> { @@ -131,5 +138,6 @@ pub fn functions() -> Vec> { substring(), unbase64(), soundex(), + make_valid_utf8(), ] } diff --git a/datafusion/sqllogictest/test_files/spark/string/make_valid_utf8.slt b/datafusion/sqllogictest/test_files/spark/string/make_valid_utf8.slt new file mode 100644 index 0000000000000..b87e0530e957d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/make_valid_utf8.slt @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query T +SELECT make_valid_utf8('Spark'::string); +---- +Spark + +query T +SELECT make_valid_utf8(''::string); +---- +(empty) + +query T +SELECT make_valid_utf8(NULL::string); +---- +NULL + +query T +SELECT make_valid_utf8(arrow_cast(x'C3A9', 'Binary')); +---- +é + +query T +SELECT make_valid_utf8(arrow_cast(x'F0908C80', 'Binary')); +---- +𐌀 + +query T +SELECT make_valid_utf8(arrow_cast(x'ED9FBF', 'Binary')); +---- +퟿ + +query T +SELECT make_valid_utf8(arrow_cast(x'FF', 'Binary')); +---- +� + +query T +SELECT make_valid_utf8(arrow_cast(x'C0AF', 'Binary')); +---- +�� + +query T +SELECT make_valid_utf8(arrow_cast(x'F4808080', 'Binary')); +---- +􀀀 + +query T +SELECT make_valid_utf8(arrow_cast(x'EDA0BDEDB2A9', 'Binary')); +---- +������ + +query T +SELECT make_valid_utf8(arrow_cast(x'F0', 'Binary')); +---- +� + +query T +SELECT make_valid_utf8(arrow_cast(x'E0', 'Binary')); +---- +� + +query T +SELECT make_valid_utf8(arrow_cast(x'F0808080', 'Binary')); +---- +���� + +query T +SELECT make_valid_utf8(arrow_cast(x'61', 'Binary')); +---- +a + +query T +SELECT make_valid_utf8(arrow_cast(x'61C262', 'Binary')); +---- +a�b