Skip to content

Commit 7f29cb0

Browse files
adriangb and claude authored
Add arrow_try_cast UDF (#21130)
## Which issue does this PR close?

N/A - new feature

## Rationale for this change

`arrow_cast(expr, 'DataType')` casts to Arrow data types specified as strings but errors on failure. `try_cast(expr AS type)` returns NULL on failure but only works with SQL types. There's currently no way to attempt a cast to a specific Arrow type and get NULL on failure instead of an error.

## What changes are included in this PR?

Adds a new `arrow_try_cast(expression, datatype)` scalar function that combines the behavior of `arrow_cast` and `try_cast`:

- Accepts Arrow data type strings (like `arrow_cast`)
- Returns NULL on cast failure instead of erroring (like `try_cast`)

Implementation details:

- Reuses `arrow_cast`'s `data_type_from_args` helper (made `pub(crate)`)
- Simplifies to `Expr::TryCast` during optimization (vs `Expr::Cast` for `arrow_cast`)
- Registered alongside existing core functions

## Are these changes tested?

Yes — new sqllogictest file `arrow_try_cast.slt` covering:

- Successful casts (Int64, Float64, LargeUtf8, Dictionary)
- Failed cast returning NULL
- Same-type passthrough
- NULL input
- Invalid type string errors
- Multiple casts in one query

## Are there any user-facing changes?

New `arrow_try_cast` SQL function available.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 56e097a commit 7f29cb0

5 files changed

Lines changed: 306 additions & 5 deletions

File tree

datafusion/functions/src/core/arrow_cast.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ impl ScalarUDFImpl for ArrowCastFunc {
163163
info: &SimplifyContext,
164164
) -> Result<ExprSimplifyResult> {
165165
// convert this into a real cast
166-
let target_type = data_type_from_args(&args)?;
166+
let target_type = data_type_from_args(self.name(), &args)?;
167167
// remove second (type) argument
168168
args.pop().unwrap();
169169
let arg = args.pop().unwrap();
@@ -189,12 +189,12 @@ impl ScalarUDFImpl for ArrowCastFunc {
189189
}
190190

191191
/// Returns the requested type from the arguments
192-
fn data_type_from_args(args: &[Expr]) -> Result<DataType> {
193-
let [_, type_arg] = take_function_args("arrow_cast", args)?;
192+
pub(crate) fn data_type_from_args(name: &str, args: &[Expr]) -> Result<DataType> {
193+
let [_, type_arg] = take_function_args(name, args)?;
194194

195195
let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else {
196196
return exec_err!(
197-
"arrow_cast requires its second argument to be a constant string, got {:?}",
197+
"{name} requires its second argument to be a constant string, got {:?}",
198198
type_arg
199199
);
200200
};
datafusion/functions/src/core/arrow_try_cast.rs

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`ArrowTryCastFunc`]: Implementation of the `arrow_try_cast`
19+
20+
use arrow::datatypes::{DataType, Field, FieldRef};
21+
use arrow::error::ArrowError;
22+
use datafusion_common::{
23+
Result, arrow_datafusion_err, datatype::DataTypeExt, exec_datafusion_err, exec_err,
24+
internal_err, types::logical_string, utils::take_function_args,
25+
};
26+
use std::any::Any;
27+
28+
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
29+
use datafusion_expr::{
30+
Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
31+
ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
32+
};
33+
use datafusion_macros::user_doc;
34+
35+
use super::arrow_cast::data_type_from_args;
36+
37+
/// Like [`arrow_cast`](super::arrow_cast::ArrowCastFunc) but returns NULL on cast failure instead of erroring.
38+
///
39+
/// This is implemented by simplifying `arrow_try_cast(expr, 'Type')` into
40+
/// `Expr::TryCast` during optimization.
41+
#[user_doc(
42+
doc_section(label = "Other Functions"),
43+
description = "Casts a value to a specific Arrow data type, returning NULL if the cast fails.",
44+
syntax_example = "arrow_try_cast(expression, datatype)",
45+
sql_example = r#"```sql
46+
> select arrow_try_cast('123', 'Int64') as a,
47+
arrow_try_cast('not_a_number', 'Int64') as b;
48+
49+
+-----+------+
50+
| a | b |
51+
+-----+------+
52+
| 123 | NULL |
53+
+-----+------+
54+
```"#,
55+
argument(
56+
name = "expression",
57+
description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators."
58+
),
59+
argument(
60+
name = "datatype",
61+
description = "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]"
62+
)
63+
)]
64+
#[derive(Debug, PartialEq, Eq, Hash)]
65+
pub struct ArrowTryCastFunc {
66+
signature: Signature,
67+
}
68+
69+
impl Default for ArrowTryCastFunc {
70+
fn default() -> Self {
71+
Self::new()
72+
}
73+
}
74+
75+
impl ArrowTryCastFunc {
76+
pub fn new() -> Self {
77+
Self {
78+
signature: Signature::coercible(
79+
vec![
80+
Coercion::new_exact(TypeSignatureClass::Any),
81+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
82+
],
83+
Volatility::Immutable,
84+
),
85+
}
86+
}
87+
}
88+
89+
impl ScalarUDFImpl for ArrowTryCastFunc {
90+
fn as_any(&self) -> &dyn Any {
91+
self
92+
}
93+
94+
fn name(&self) -> &str {
95+
"arrow_try_cast"
96+
}
97+
98+
fn signature(&self) -> &Signature {
99+
&self.signature
100+
}
101+
102+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
103+
internal_err!("return_field_from_args should be called instead")
104+
}
105+
106+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
107+
// TryCast can always return NULL (on cast failure), so always nullable
108+
let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?;
109+
110+
type_arg
111+
.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
112+
.map_or_else(
113+
|| {
114+
exec_err!(
115+
"{} requires its second argument to be a non-empty constant string",
116+
self.name()
117+
)
118+
},
119+
|casted_type| match casted_type.parse::<DataType>() {
120+
Ok(data_type) => {
121+
Ok(Field::new(self.name(), data_type, true).into())
122+
}
123+
Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")),
124+
Err(e) => Err(arrow_datafusion_err!(e)),
125+
},
126+
)
127+
}
128+
129+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
130+
internal_err!("arrow_try_cast should have been simplified to try_cast")
131+
}
132+
133+
fn simplify(
134+
&self,
135+
mut args: Vec<Expr>,
136+
info: &SimplifyContext,
137+
) -> Result<ExprSimplifyResult> {
138+
let target_type = data_type_from_args(self.name(), &args)?;
139+
// remove second (type) argument
140+
args.pop().unwrap();
141+
let arg = args.pop().unwrap();
142+
143+
let source_type = info.get_data_type(&arg)?;
144+
let new_expr = if source_type == target_type {
145+
arg
146+
} else {
147+
Expr::TryCast(datafusion_expr::TryCast {
148+
expr: Box::new(arg),
149+
field: target_type.into_nullable_field_ref(),
150+
})
151+
};
152+
Ok(ExprSimplifyResult::Simplified(new_expr))
153+
}
154+
155+
fn documentation(&self) -> Option<&Documentation> {
156+
self.doc()
157+
}
158+
}

datafusion/functions/src/core/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use std::sync::Arc;
2222

2323
pub mod arrow_cast;
2424
pub mod arrow_metadata;
25+
pub mod arrow_try_cast;
2526
pub mod arrowtypeof;
2627
pub mod coalesce;
2728
pub mod expr_ext;
@@ -42,6 +43,7 @@ pub mod version;
4243

4344
// create UDFs
4445
make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast);
46+
make_udf_function!(arrow_try_cast::ArrowTryCastFunc, arrow_try_cast);
4547
make_udf_function!(nullif::NullIfFunc, nullif);
4648
make_udf_function!(nvl::NVLFunc, nvl);
4749
make_udf_function!(nvl2::NVL2Func, nvl2);
@@ -67,7 +69,11 @@ pub mod expr_fn {
6769
arg1 arg2
6870
),(
6971
arrow_cast,
70-
"Returns value2 if value1 is NULL; otherwise it returns value1",
72+
"Casts a value to a specific Arrow data type",
73+
arg1 arg2
74+
),(
75+
arrow_try_cast,
76+
"Casts a value to a specific Arrow data type, returning NULL if the cast fails",
7177
arg1 arg2
7278
),(
7379
nvl,
@@ -140,6 +146,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
140146
vec![
141147
nullif(),
142148
arrow_cast(),
149+
arrow_try_cast(),
143150
arrow_metadata(),
144151
nvl(),
145152
nvl2(),
arrow_try_cast.slt

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
##########
19+
# Tests for arrow_try_cast: like arrow_cast but returns NULL on cast failure
20+
##########
21+
22+
# Successful cast to Float64
23+
query R
24+
select arrow_try_cast(123, 'Float64');
25+
----
26+
123
27+
28+
# Successful cast to Int64
29+
query I
30+
select arrow_try_cast('123', 'Int64');
31+
----
32+
123
33+
34+
# Failed cast returns NULL
35+
query I
36+
select arrow_try_cast('not_a_number', 'Int64');
37+
----
38+
NULL
39+
40+
# Same-type passthrough
41+
query I
42+
select arrow_try_cast(1, 'Int32');
43+
----
44+
1
45+
46+
# Cast to LargeUtf8
47+
query T
48+
select arrow_try_cast('foo', 'LargeUtf8');
49+
----
50+
foo
51+
52+
# Cast integer to string
53+
query T
54+
select arrow_try_cast(42, 'Utf8');
55+
----
56+
42
57+
58+
# Cast to dictionary type
59+
query T
60+
select arrow_try_cast('bar', 'Dictionary(Int32, Utf8)');
61+
----
62+
bar
63+
64+
# NULL input stays NULL
65+
query I
66+
select arrow_try_cast(NULL, 'Int64');
67+
----
68+
NULL
69+
70+
# Error on invalid type string
71+
statement error
72+
select arrow_try_cast(1, 'NotAType');
73+
74+
# Error when second argument is not a string constant
75+
statement error
76+
select arrow_try_cast(1, 123);
77+
78+
# Multiple arrow_try_cast in one query
79+
query IT
80+
select arrow_try_cast('456', 'Int64') as a,
81+
arrow_try_cast(789, 'Utf8') as b;
82+
----
83+
456 789
84+
85+
# Tests that exercise physical execution (not constant folding)
86+
87+
# Cast column values to Int64, with mixed valid/null/invalid inputs
88+
query I
89+
select arrow_try_cast(a, 'Int64') from (values('100'), (NULL), ('foo')) t(a);
90+
----
91+
100
92+
NULL
93+
NULL
94+
95+
# Cast column values to Float64
96+
query R
97+
select arrow_try_cast(a, 'Float64') from (values('3.14'), ('not_num'), (NULL)) t(a);
98+
----
99+
3.14
100+
NULL
101+
NULL
102+
103+
# Cast integer column to Utf8
104+
query T
105+
select arrow_try_cast(a, 'Utf8') from (values(1), (2), (NULL)) t(a);
106+
----
107+
1
108+
2
109+
NULL

docs/source/user-guide/sql/scalar_functions.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5185,6 +5185,7 @@ union_tag(union_expression)
51855185

51865186
- [arrow_cast](#arrow_cast)
51875187
- [arrow_metadata](#arrow_metadata)
5188+
- [arrow_try_cast](#arrow_try_cast)
51885189
- [arrow_typeof](#arrow_typeof)
51895190
- [get_field](#get_field)
51905191
- [version](#version)
@@ -5257,6 +5258,32 @@ arrow_metadata(expression[, key])
52575258
+-------------------------------+
52585259
```
52595260

5261+
### `arrow_try_cast`
5262+
5263+
Casts a value to a specific Arrow data type, returning NULL if the cast fails.
5264+
5265+
```sql
5266+
arrow_try_cast(expression, datatype)
5267+
```
5268+
5269+
#### Arguments
5270+
5271+
- **expression**: Expression to cast. The expression can be a constant, column, or function, and any combination of operators.
5272+
- **datatype**: [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]
5273+
5274+
#### Example
5275+
5276+
```sql
5277+
> select arrow_try_cast('123', 'Int64') as a,
5278+
arrow_try_cast('not_a_number', 'Int64') as b;
5279+
5280+
+-----+------+
5281+
| a | b |
5282+
+-----+------+
5283+
| 123 | NULL |
5284+
+-----+------+
5285+
```
5286+
52605287
### `arrow_typeof`
52615288

52625289
Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression.

0 commit comments

Comments
 (0)