@@ -24,8 +24,10 @@ use crate::string::concat;
2424use crate :: strings:: {
2525 ColumnarValueRef , LargeStringArrayBuilder , StringArrayBuilder , StringViewArrayBuilder ,
2626} ;
27- use datafusion_common:: cast:: { as_string_array, as_string_view_array} ;
28- use datafusion_common:: { Result , ScalarValue , internal_err, plan_err} ;
27+ use datafusion_common:: cast:: { as_binary_array, as_string_array, as_string_view_array} ;
28+ use datafusion_common:: {
29+ Result , ScalarValue , exec_datafusion_err, internal_err, plan_err,
30+ } ;
2931use datafusion_expr:: expr:: ScalarFunction ;
3032use datafusion_expr:: simplify:: { ExprSimplifyResult , SimplifyContext } ;
3133use datafusion_expr:: { ColumnarValue , Documentation , Expr , Volatility , lit} ;
@@ -67,13 +69,24 @@ impl ConcatFunc {
6769 use DataType :: * ;
6870 Self {
6971 signature : Signature :: variadic (
70- vec ! [ Utf8View , Utf8 , LargeUtf8 ] ,
72+ vec ! [ Utf8View , Utf8 , LargeUtf8 , Binary ] ,
7173 Volatility :: Immutable ,
7274 ) ,
7375 }
7476 }
7577}
7678
79+ fn deduce_return_type ( arg_types : & [ DataType ] ) -> DataType {
80+ use DataType :: * ;
81+ if arg_types. contains ( & Utf8View ) {
82+ Utf8View
83+ } else if arg_types. contains ( & LargeUtf8 ) {
84+ LargeUtf8
85+ } else {
86+ Utf8
87+ }
88+ }
89+
7790impl ScalarUDFImpl for ConcatFunc {
7891 fn name ( & self ) -> & str {
7992 "concat"
@@ -87,29 +100,16 @@ impl ScalarUDFImpl for ConcatFunc {
87100 /// mixed inputs, prefer Utf8View; prefer LargeUtf8 over Utf8 to avoid
88101 /// potential overflow on LargeUtf8 input.
89102 fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
90- use DataType :: * ;
91- if arg_types. contains ( & Utf8View ) {
92- Ok ( Utf8View )
93- } else if arg_types. contains ( & LargeUtf8 ) {
94- Ok ( LargeUtf8 )
95- } else {
96- Ok ( Utf8 )
97- }
103+ Ok ( deduce_return_type ( arg_types) )
98104 }
99105
100106 /// Concatenates the text representations of all the arguments. NULL arguments are ignored.
101107 /// concat('abcde', 2, NULL, 22) = 'abcde222'
102108 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
103109 let ScalarFunctionArgs { args, .. } = args;
104110
105- let return_datatype = if args. iter ( ) . any ( |c| c. data_type ( ) == DataType :: Utf8View )
106- {
107- DataType :: Utf8View
108- } else if args. iter ( ) . any ( |c| c. data_type ( ) == DataType :: LargeUtf8 ) {
109- DataType :: LargeUtf8
110- } else {
111- DataType :: Utf8
112- } ;
111+ let arg_types: Vec < DataType > = args. iter ( ) . map ( |c| c. data_type ( ) ) . collect ( ) ;
112+ let return_datatype = deduce_return_type ( & arg_types) ;
113113
114114 let array_len = args. iter ( ) . find_map ( |x| match x {
115115 ColumnarValue :: Array ( array) => Some ( array. len ( ) ) ,
@@ -118,22 +118,28 @@ impl ScalarUDFImpl for ConcatFunc {
118118
119119 // Scalar
120120 if array_len. is_none ( ) {
121- let mut values = Vec :: with_capacity ( args. len ( ) ) ;
121+ let mut values: Vec < & [ u8 ] > = Vec :: with_capacity ( args. len ( ) ) ;
122122 for arg in & args {
123123 let ColumnarValue :: Scalar ( scalar) = arg else {
124124 return internal_err ! ( "concat expected scalar value, got {arg:?}" ) ;
125125 } ;
126-
127- match scalar. try_as_str ( ) {
128- Some ( Some ( v) ) => values. push ( v) ,
129- Some ( None ) => { } // null literal
130- None => plan_err ! (
131- "Concat function does not support scalar type {}" ,
132- scalar
133- ) ?,
126+ if let ScalarValue :: Binary ( Some ( value) ) = scalar {
127+ values. push ( value) ;
128+ } else {
129+ match scalar. try_as_str ( ) {
130+ Some ( Some ( v) ) => values. push ( v. as_bytes ( ) ) ,
131+ Some ( None ) => { } // null literal
132+ None => plan_err ! (
133+ "Concat function does not support scalar type {}" ,
134+ scalar
135+ ) ?,
136+ }
134137 }
135138 }
136- let result = values. concat ( ) ;
139+ let concat_bytes = values. concat ( ) ;
140+ let result = std:: str:: from_utf8 ( & concat_bytes)
141+ . map_err ( |_| exec_datafusion_err ! ( "invalid UTF-8 in binary literal" ) ) ?
142+ . to_string ( ) ;
137143
138144 return match return_datatype {
139145 DataType :: Utf8View => {
@@ -166,6 +172,13 @@ impl ScalarUDFImpl for ConcatFunc {
166172 columns. push ( ColumnarValueRef :: Scalar ( s. as_bytes ( ) ) ) ;
167173 }
168174 }
175+ ColumnarValue :: Scalar ( ScalarValue :: Binary ( maybe_value) ) => {
176+ if let Some ( b) = maybe_value {
177+ // data_size is a capacity hint, so doesn't matter if it is chars or bytes
178+ data_size += b. len ( ) * len;
179+ columns. push ( ColumnarValueRef :: Scalar ( b. as_slice ( ) ) ) ;
180+ }
181+ }
169182 ColumnarValue :: Array ( array) => {
170183 match array. data_type ( ) {
171184 DataType :: Utf8 => {
@@ -205,6 +218,17 @@ impl ScalarUDFImpl for ConcatFunc {
205218 } ;
206219 columns. push ( column) ;
207220 }
221+ DataType :: Binary => {
222+ let string_array = as_binary_array ( array) ?;
223+
224+ data_size += string_array. values ( ) . len ( ) ;
225+ let column = if array. is_nullable ( ) {
226+ ColumnarValueRef :: NullableBinaryArray ( string_array)
227+ } else {
228+ ColumnarValueRef :: NonNullableBinaryArray ( string_array)
229+ } ;
230+ columns. push ( column) ;
231+ }
208232 other => {
209233 return plan_err ! (
210234 "Input was {other} which is not a supported datatype for concat function"
@@ -226,7 +250,7 @@ impl ScalarUDFImpl for ConcatFunc {
226250 builder. append_offset ( ) ;
227251 }
228252
229- let string_array = builder. finish ( None ) ;
253+ let string_array = builder. finish ( None ) ? ;
230254 Ok ( ColumnarValue :: Array ( Arc :: new ( string_array) ) )
231255 }
232256 DataType :: Utf8View => {
@@ -235,10 +259,10 @@ impl ScalarUDFImpl for ConcatFunc {
235259 columns
236260 . iter ( )
237261 . for_each ( |column| builder. write :: < true > ( column, i) ) ;
238- builder. append_offset ( ) ;
262+ builder. append_offset ( ) ? ;
239263 }
240264
241- let string_array = builder. finish ( None ) ;
265+ let string_array = builder. finish ( None ) ? ;
242266 Ok ( ColumnarValue :: Array ( Arc :: new ( string_array) ) )
243267 }
244268 DataType :: LargeUtf8 => {
@@ -250,7 +274,7 @@ impl ScalarUDFImpl for ConcatFunc {
250274 builder. append_offset ( ) ;
251275 }
252276
253- let string_array = builder. finish ( None ) ;
277+ let string_array = builder. finish ( None ) ? ;
254278 Ok ( ColumnarValue :: Array ( Arc :: new ( string_array) ) )
255279 }
256280 _ => unreachable ! ( ) ,
@@ -446,7 +470,33 @@ mod tests {
446470 Utf8View ,
447471 StringViewArray
448472 ) ;
449-
473+ test_function ! (
474+ ConcatFunc :: new( ) ,
475+ vec![
476+ ColumnarValue :: Scalar ( ScalarValue :: Binary ( Some (
477+ "Café" . as_bytes( ) . into( )
478+ ) ) ) ,
479+ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( None ) ) ,
480+ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( "cc" . to_string( ) ) ) ) ,
481+ ] ,
482+ Ok ( Some ( "Cafécc" ) ) ,
483+ & str ,
484+ Utf8 ,
485+ StringArray
486+ ) ;
487+ test_function ! (
488+ ConcatFunc :: new( ) ,
489+ vec![
490+ ColumnarValue :: Scalar ( ScalarValue :: Binary ( Some ( Vec :: from(
491+ "Café" . as_bytes( )
492+ ) ) ) ) ,
493+ ColumnarValue :: Scalar ( ScalarValue :: Binary ( Some ( "cc" . as_bytes( ) . into( ) ) ) ) ,
494+ ] ,
495+ Ok ( Some ( "Cafécc" ) ) ,
496+ & str ,
497+ Utf8 ,
498+ StringArray
499+ ) ;
450500 Ok ( ( ) )
451501 }
452502
0 commit comments