Skip to content

Commit 29c454b

Browse files
committed
feat: enhance TryCastExpr with target field semantics
- Added a target_field attribute to TryCastExpr to store field metadata for the desired output. - Introduced methods for creating TryCastExpr instances with and without explicit target field information, ensuring that casting semantics are preserved. - Updated the return_field and cast_type methods to reflect changes in field metadata handling. - Enhanced the try_cast and try_cast_with_target_field functions to accept and utilize the target field. - Added tests to validate the preservation of target field semantics in TryCastExpr operations and enhanced integration with the logical to physical expression conversion.
1 parent e4445db commit 29c454b

5 files changed

Lines changed: 210 additions & 57 deletions

File tree

datafusion/physical-expr/src/expressions/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,6 @@ pub use literal::{Literal, lit};
5555
pub use negative::{NegativeExpr, negative};
5656
pub use no_op::NoOp;
5757
pub use not::{NotExpr, not};
58+
pub(crate) use try_cast::try_cast_with_target_field;
5859
pub use try_cast::{TryCastExpr, try_cast};
5960
pub use unknown_column::UnKnownColumn;

datafusion/physical-expr/src/expressions/try_cast.rs

Lines changed: 118 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use arrow::compute::CastOptions;
2626
use arrow::datatypes::{DataType, FieldRef, Schema};
2727
use arrow::record_batch::RecordBatch;
2828
use compute::can_cast_types;
29+
use datafusion_common::datatype::DataTypeExt;
2930
use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
3031
use datafusion_common::{Result, not_impl_err};
3132
use datafusion_expr::ColumnarValue;
@@ -35,28 +36,36 @@ use datafusion_expr::ColumnarValue;
3536
pub struct TryCastExpr {
3637
/// The expression to cast
3738
expr: Arc<dyn PhysicalExpr>,
38-
/// The data type to cast to
39-
cast_type: DataType,
39+
/// Field metadata describing the desired output after casting
40+
target_field: FieldRef,
4041
}
4142

4243
// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
4344
impl PartialEq for TryCastExpr {
4445
fn eq(&self, other: &Self) -> bool {
45-
self.expr.eq(&other.expr) && self.cast_type == other.cast_type
46+
self.expr.eq(&other.expr) && self.target_field == other.target_field
4647
}
4748
}
4849

4950
impl Hash for TryCastExpr {
5051
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
5152
self.expr.hash(state);
52-
self.cast_type.hash(state);
53+
self.target_field.hash(state);
5354
}
5455
}
5556

5657
impl TryCastExpr {
5758
/// Create a new CastExpr
5859
pub fn new(expr: Arc<dyn PhysicalExpr>, cast_type: DataType) -> Self {
59-
Self { expr, cast_type }
60+
Self::new_with_target_field(expr, cast_type.into_nullable_field_ref())
61+
}
62+
63+
/// Create a new TryCastExpr with an explicit target field.
64+
pub fn new_with_target_field(
65+
expr: Arc<dyn PhysicalExpr>,
66+
target_field: FieldRef,
67+
) -> Self {
68+
Self { expr, target_field }
6069
}
6170

6271
/// The expression to cast
@@ -66,13 +75,45 @@ impl TryCastExpr {
6675

6776
/// The data type to cast to
6877
pub fn cast_type(&self) -> &DataType {
69-
&self.cast_type
78+
self.target_field.data_type()
79+
}
80+
81+
/// Field metadata describing the output column after casting.
82+
pub fn target_field(&self) -> &FieldRef {
83+
&self.target_field
84+
}
85+
86+
fn is_default_target_field(&self) -> bool {
87+
self.target_field.name().is_empty()
88+
&& self.target_field.is_nullable()
89+
&& self.target_field.metadata().is_empty()
90+
}
91+
92+
fn resolved_target_field(&self, input_schema: &Schema) -> Result<FieldRef> {
93+
if self.is_default_target_field() {
94+
self.expr.return_field(input_schema).map(|field| {
95+
Arc::new(
96+
field
97+
.as_ref()
98+
.clone()
99+
.with_data_type(self.cast_type().clone())
100+
.with_nullable(true),
101+
)
102+
})
103+
} else {
104+
Ok(Arc::clone(&self.target_field))
105+
}
106+
}
107+
108+
fn preserves_child_field_semantics(&self, input_schema: &Schema) -> Result<bool> {
109+
Ok(self.resolved_target_field(input_schema)?
110+
== self.expr.return_field(input_schema)?)
70111
}
71112
}
72113

73114
impl fmt::Display for TryCastExpr {
74115
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
75-
write!(f, "TRY_CAST({} AS {})", self.expr, self.cast_type)
116+
write!(f, "TRY_CAST({} AS {})", self.expr, self.cast_type())
76117
}
77118
}
78119

@@ -83,7 +124,7 @@ impl PhysicalExpr for TryCastExpr {
83124
}
84125

85126
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
86-
Ok(self.cast_type.clone())
127+
Ok(self.cast_type().clone())
87128
}
88129

89130
fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
@@ -96,14 +137,11 @@ impl PhysicalExpr for TryCastExpr {
96137
safe: true,
97138
format_options: DEFAULT_FORMAT_OPTIONS,
98139
};
99-
value.cast_to(&self.cast_type, Some(&options))
140+
value.cast_to(self.cast_type(), Some(&options))
100141
}
101142

102143
fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
103-
self.expr
104-
.return_field(input_schema)
105-
.map(|f| f.as_ref().clone().with_data_type(self.cast_type.clone()))
106-
.map(Arc::new)
144+
self.resolved_target_field(input_schema)
107145
}
108146

109147
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
@@ -114,16 +152,16 @@ impl PhysicalExpr for TryCastExpr {
114152
self: Arc<Self>,
115153
children: Vec<Arc<dyn PhysicalExpr>>,
116154
) -> Result<Arc<dyn PhysicalExpr>> {
117-
Ok(Arc::new(TryCastExpr::new(
155+
Ok(Arc::new(TryCastExpr::new_with_target_field(
118156
Arc::clone(&children[0]),
119-
self.cast_type.clone(),
157+
Arc::clone(&self.target_field),
120158
)))
121159
}
122160

123161
fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124162
write!(f, "TRY_CAST(")?;
125163
self.expr.fmt_sql(f)?;
126-
write!(f, " AS {:?})", self.cast_type)
164+
write!(f, " AS {:?})", self.cast_type())
127165
}
128166
}
129167

@@ -135,12 +173,26 @@ pub fn try_cast(
135173
expr: Arc<dyn PhysicalExpr>,
136174
input_schema: &Schema,
137175
cast_type: DataType,
176+
) -> Result<Arc<dyn PhysicalExpr>> {
177+
try_cast_with_target_field(expr, input_schema, cast_type.into_nullable_field_ref())
178+
}
179+
180+
pub(crate) fn try_cast_with_target_field(
181+
expr: Arc<dyn PhysicalExpr>,
182+
input_schema: &Schema,
183+
target_field: FieldRef,
138184
) -> Result<Arc<dyn PhysicalExpr>> {
139185
let expr_type = expr.data_type(input_schema)?;
140-
if expr_type == cast_type {
141-
Ok(Arc::clone(&expr))
142-
} else if can_cast_types(&expr_type, &cast_type) {
143-
Ok(Arc::new(TryCastExpr::new(expr, cast_type)))
186+
let cast_type = target_field.data_type().clone();
187+
let candidate = TryCastExpr::new_with_target_field(Arc::clone(&expr), target_field);
188+
189+
if expr_type == cast_type
190+
&& (candidate.is_default_target_field()
191+
|| candidate.preserves_child_field_semantics(input_schema)?)
192+
{
193+
Ok(expr)
194+
} else if expr_type == cast_type || can_cast_types(&expr_type, &cast_type) {
195+
Ok(Arc::new(candidate))
144196
} else {
145197
not_impl_err!("Unsupported TRY_CAST from {expr_type} to {cast_type}")
146198
}
@@ -161,6 +213,7 @@ mod tests {
161213
datatypes::*,
162214
};
163215
use datafusion_physical_expr_common::physical_expr::fmt_sql;
216+
use std::collections::HashMap;
164217

165218
// runs an end-to-end test of physical type cast
166219
// 1. construct a record batch with a column "a" of type A
@@ -564,6 +617,51 @@ mod tests {
564617
result.expect_err("expected Invalid TRY_CAST");
565618
}
566619

620+
#[test]
621+
fn field_aware_same_type_try_cast_preserves_explicit_target_field() -> Result<()> {
622+
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
623+
let a_col = col("a", &schema)?;
624+
let logical_field = Arc::new(
625+
Field::new("logical_a", DataType::Int32, true).with_metadata(HashMap::from(
626+
[("target_meta".to_string(), "1".to_string())],
627+
)),
628+
);
629+
let expr = try_cast_with_target_field(a_col, &schema, logical_field)?;
630+
631+
let try_cast_expr = expr
632+
.as_any()
633+
.downcast_ref::<TryCastExpr>()
634+
.expect("explicit same-type target should preserve TryCastExpr");
635+
let field = try_cast_expr.return_field(&schema)?;
636+
637+
assert_eq!(field.name(), "logical_a");
638+
assert!(field.is_nullable());
639+
assert_eq!(
640+
field.metadata().get("target_meta").map(String::as_str),
641+
Some("1")
642+
);
643+
assert!(expr.nullable(&schema)?);
644+
645+
Ok(())
646+
}
647+
648+
#[test]
649+
fn default_same_type_try_cast_is_elided() -> Result<()> {
650+
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
651+
let a_col = col("a", &schema)?;
652+
let target_field = DataType::Int32.into_nullable_field_ref();
653+
let expr = try_cast_with_target_field(a_col, &schema, target_field)?;
654+
655+
assert!(
656+
expr.as_any()
657+
.downcast_ref::<crate::expressions::Column>()
658+
.is_some()
659+
);
660+
assert!(expr.as_any().downcast_ref::<TryCastExpr>().is_none());
661+
662+
Ok(())
663+
}
664+
567665
// create decimal array with the specified precision and scale
568666
fn create_decimal_array(array: &[i128], precision: u8, scale: i8) -> Decimal128Array {
569667
let mut decimal_builder = Decimal128Builder::with_capacity(array.len());

datafusion/physical-expr/src/planner.rs

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::{
2525

2626
use arrow::datatypes::Schema;
2727
use datafusion_common::config::ConfigOptions;
28-
use datafusion_common::metadata::{FieldMetadata, format_type_and_metadata};
28+
use datafusion_common::metadata::FieldMetadata;
2929
use datafusion_common::{
3030
DFSchema, Result, ScalarValue, ToDFSchema, exec_err, not_impl_err, plan_err,
3131
};
@@ -34,7 +34,7 @@ use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction};
3434
use datafusion_expr::var_provider::VarType;
3535
use datafusion_expr::var_provider::is_system_variables;
3636
use datafusion_expr::{
37-
Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, TryCast, binary_expr, lit,
37+
Between, BinaryExpr, Expr, Like, Operator, TryCast, binary_expr, lit,
3838
};
3939

4040
/// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1
@@ -298,22 +298,10 @@ pub fn create_physical_expr(
298298
)
299299
}
300300
Expr::TryCast(TryCast { expr, field }) => {
301-
if !field.metadata().is_empty() {
302-
let (_, src_field) = expr.to_field(input_dfschema)?;
303-
return plan_err!(
304-
"TryCast from {} to {} is not supported",
305-
format_type_and_metadata(
306-
src_field.data_type(),
307-
Some(src_field.metadata()),
308-
),
309-
format_type_and_metadata(field.data_type(), Some(field.metadata()))
310-
);
311-
}
312-
313-
expressions::try_cast(
301+
expressions::try_cast_with_target_field(
314302
create_physical_expr(expr, input_dfschema, execution_props)?,
315303
input_schema,
316-
field.data_type().clone(),
304+
Arc::clone(field),
317305
)
318306
}
319307
Expr::Not(expr) => {
@@ -433,7 +421,7 @@ pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc<dyn PhysicalExpr> {
433421

434422
#[cfg(test)]
435423
mod tests {
436-
use crate::expressions::{CastExpr, Column};
424+
use crate::expressions::{CastExpr, Column, TryCastExpr};
437425
use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
438426
use arrow::datatypes::{DataType, Field};
439427
use datafusion_common::datatype::DataTypeExt;
@@ -498,10 +486,12 @@ mod tests {
498486
Box::new(expr.clone()),
499487
Arc::clone(&extension_field_type),
500488
));
501-
let err =
502-
create_physical_expr(&try_cast_expr, &df_schema, &ExecutionProps::new())
503-
.unwrap_err();
504-
assert!(err.message().contains(extension_name));
489+
let physical_expr =
490+
create_physical_expr(&try_cast_expr, &df_schema, &ExecutionProps::new())?;
491+
let field = physical_expr.return_field(&schema)?;
492+
assert_eq!(field.data_type(), extension_field_type.data_type());
493+
assert_eq!(field.metadata(), extension_field_type.metadata());
494+
assert!(physical_expr.nullable(&schema)?);
505495

506496
Ok(())
507497
}
@@ -607,6 +597,43 @@ mod tests {
607597
Ok(())
608598
}
609599

600+
#[test]
601+
fn test_create_physical_expr_try_cast_preserves_target_field_semantics() -> Result<()>
602+
{
603+
let schema = Schema::new(vec![Field::new("value", DataType::Utf8, false)]);
604+
let df_schema = DFSchema::try_from(schema.clone())?;
605+
let target_field = Arc::new(
606+
Field::new("logical_try_cast", DataType::Int64, true).with_metadata(
607+
HashMap::from([("semantic".to_string(), "target".to_string())]),
608+
),
609+
);
610+
let expr = Expr::TryCast(TryCast::new_from_field(
611+
Box::new(col("value")),
612+
Arc::clone(&target_field),
613+
));
614+
615+
let physical_expr =
616+
create_physical_expr(&expr, &df_schema, &ExecutionProps::new())?;
617+
let try_cast_expr = physical_expr
618+
.as_any()
619+
.downcast_ref::<TryCastExpr>()
620+
.expect("logical try_cast should lower to TryCastExpr");
621+
622+
assert_eq!(try_cast_expr.target_field(), &target_field);
623+
624+
let field = physical_expr.return_field(&schema)?;
625+
assert_eq!(field.name(), "logical_try_cast");
626+
assert_eq!(field.data_type(), &DataType::Int64);
627+
assert!(field.is_nullable());
628+
assert_eq!(
629+
field.metadata().get("semantic").map(String::as_str),
630+
Some("target")
631+
);
632+
assert!(physical_expr.nullable(&schema)?);
633+
634+
Ok(())
635+
}
636+
610637
/// Test that deeply nested expressions do not cause a stack overflow.
611638
///
612639
/// This test only runs when the `recursive_protection` feature is enabled,

0 commit comments

Comments
 (0)