@@ -26,6 +26,7 @@ use arrow::compute::CastOptions;
2626use arrow:: datatypes:: { DataType , FieldRef , Schema } ;
2727use arrow:: record_batch:: RecordBatch ;
2828use compute:: can_cast_types;
29+ use datafusion_common:: datatype:: DataTypeExt ;
2930use datafusion_common:: format:: DEFAULT_FORMAT_OPTIONS ;
3031use datafusion_common:: { Result , not_impl_err} ;
3132use datafusion_expr:: ColumnarValue ;
@@ -35,28 +36,36 @@ use datafusion_expr::ColumnarValue;
3536pub struct TryCastExpr {
3637 /// The expression to cast
3738 expr : Arc < dyn PhysicalExpr > ,
38- /// The data type to cast to
39- cast_type : DataType ,
39+ /// Field metadata describing the desired output after casting
40+ target_field : FieldRef ,
4041}
4142
4243// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
4344impl PartialEq for TryCastExpr {
4445 fn eq ( & self , other : & Self ) -> bool {
45- self . expr . eq ( & other. expr ) && self . cast_type == other. cast_type
46+ self . expr . eq ( & other. expr ) && self . target_field == other. target_field
4647 }
4748}
4849
4950impl Hash for TryCastExpr {
5051 fn hash < H : std:: hash:: Hasher > ( & self , state : & mut H ) {
5152 self . expr . hash ( state) ;
52- self . cast_type . hash ( state) ;
53+ self . target_field . hash ( state) ;
5354 }
5455}
5556
5657impl TryCastExpr {
5758 /// Create a new CastExpr
5859 pub fn new ( expr : Arc < dyn PhysicalExpr > , cast_type : DataType ) -> Self {
59- Self { expr, cast_type }
60+ Self :: new_with_target_field ( expr, cast_type. into_nullable_field_ref ( ) )
61+ }
62+
63+ /// Create a new TryCastExpr with an explicit target field.
64+ pub fn new_with_target_field (
65+ expr : Arc < dyn PhysicalExpr > ,
66+ target_field : FieldRef ,
67+ ) -> Self {
68+ Self { expr, target_field }
6069 }
6170
6271 /// The expression to cast
@@ -66,13 +75,45 @@ impl TryCastExpr {
6675
6776 /// The data type to cast to
6877 pub fn cast_type ( & self ) -> & DataType {
69- & self . cast_type
78+ self . target_field . data_type ( )
79+ }
80+
81+ /// Field metadata describing the output column after casting.
82+ pub fn target_field ( & self ) -> & FieldRef {
83+ & self . target_field
84+ }
85+
86+ fn is_default_target_field ( & self ) -> bool {
87+ self . target_field . name ( ) . is_empty ( )
88+ && self . target_field . is_nullable ( )
89+ && self . target_field . metadata ( ) . is_empty ( )
90+ }
91+
92+ fn resolved_target_field ( & self , input_schema : & Schema ) -> Result < FieldRef > {
93+ if self . is_default_target_field ( ) {
94+ self . expr . return_field ( input_schema) . map ( |field| {
95+ Arc :: new (
96+ field
97+ . as_ref ( )
98+ . clone ( )
99+ . with_data_type ( self . cast_type ( ) . clone ( ) )
100+ . with_nullable ( true ) ,
101+ )
102+ } )
103+ } else {
104+ Ok ( Arc :: clone ( & self . target_field ) )
105+ }
106+ }
107+
108+ fn preserves_child_field_semantics ( & self , input_schema : & Schema ) -> Result < bool > {
109+ Ok ( self . resolved_target_field ( input_schema) ?
110+ == self . expr . return_field ( input_schema) ?)
70111 }
71112}
72113
73114impl fmt:: Display for TryCastExpr {
74115 fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
75- write ! ( f, "TRY_CAST({} AS {})" , self . expr, self . cast_type)
116+ write ! ( f, "TRY_CAST({} AS {})" , self . expr, self . cast_type( ) )
76117 }
77118}
78119
@@ -83,7 +124,7 @@ impl PhysicalExpr for TryCastExpr {
83124 }
84125
85126 fn data_type ( & self , _input_schema : & Schema ) -> Result < DataType > {
86- Ok ( self . cast_type . clone ( ) )
127+ Ok ( self . cast_type ( ) . clone ( ) )
87128 }
88129
89130 fn nullable ( & self , _input_schema : & Schema ) -> Result < bool > {
@@ -96,14 +137,11 @@ impl PhysicalExpr for TryCastExpr {
96137 safe : true ,
97138 format_options : DEFAULT_FORMAT_OPTIONS ,
98139 } ;
99- value. cast_to ( & self . cast_type , Some ( & options) )
140+ value. cast_to ( self . cast_type ( ) , Some ( & options) )
100141 }
101142
102143 fn return_field ( & self , input_schema : & Schema ) -> Result < FieldRef > {
103- self . expr
104- . return_field ( input_schema)
105- . map ( |f| f. as_ref ( ) . clone ( ) . with_data_type ( self . cast_type . clone ( ) ) )
106- . map ( Arc :: new)
144+ self . resolved_target_field ( input_schema)
107145 }
108146
109147 fn children ( & self ) -> Vec < & Arc < dyn PhysicalExpr > > {
@@ -114,16 +152,16 @@ impl PhysicalExpr for TryCastExpr {
114152 self : Arc < Self > ,
115153 children : Vec < Arc < dyn PhysicalExpr > > ,
116154 ) -> Result < Arc < dyn PhysicalExpr > > {
117- Ok ( Arc :: new ( TryCastExpr :: new (
155+ Ok ( Arc :: new ( TryCastExpr :: new_with_target_field (
118156 Arc :: clone ( & children[ 0 ] ) ,
119- self . cast_type . clone ( ) ,
157+ Arc :: clone ( & self . target_field ) ,
120158 ) ) )
121159 }
122160
123161 fn fmt_sql ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
124162 write ! ( f, "TRY_CAST(" ) ?;
125163 self . expr . fmt_sql ( f) ?;
126- write ! ( f, " AS {:?})" , self . cast_type)
164+ write ! ( f, " AS {:?})" , self . cast_type( ) )
127165 }
128166}
129167
@@ -135,12 +173,26 @@ pub fn try_cast(
135173 expr : Arc < dyn PhysicalExpr > ,
136174 input_schema : & Schema ,
137175 cast_type : DataType ,
176+ ) -> Result < Arc < dyn PhysicalExpr > > {
177+ try_cast_with_target_field ( expr, input_schema, cast_type. into_nullable_field_ref ( ) )
178+ }
179+
180+ pub ( crate ) fn try_cast_with_target_field (
181+ expr : Arc < dyn PhysicalExpr > ,
182+ input_schema : & Schema ,
183+ target_field : FieldRef ,
138184) -> Result < Arc < dyn PhysicalExpr > > {
139185 let expr_type = expr. data_type ( input_schema) ?;
140- if expr_type == cast_type {
141- Ok ( Arc :: clone ( & expr) )
142- } else if can_cast_types ( & expr_type, & cast_type) {
143- Ok ( Arc :: new ( TryCastExpr :: new ( expr, cast_type) ) )
186+ let cast_type = target_field. data_type ( ) . clone ( ) ;
187+ let candidate = TryCastExpr :: new_with_target_field ( Arc :: clone ( & expr) , target_field) ;
188+
189+ if expr_type == cast_type
190+ && ( candidate. is_default_target_field ( )
191+ || candidate. preserves_child_field_semantics ( input_schema) ?)
192+ {
193+ Ok ( expr)
194+ } else if expr_type == cast_type || can_cast_types ( & expr_type, & cast_type) {
195+ Ok ( Arc :: new ( candidate) )
144196 } else {
145197 not_impl_err ! ( "Unsupported TRY_CAST from {expr_type} to {cast_type}" )
146198 }
@@ -161,6 +213,7 @@ mod tests {
161213 datatypes:: * ,
162214 } ;
163215 use datafusion_physical_expr_common:: physical_expr:: fmt_sql;
216+ use std:: collections:: HashMap ;
164217
165218 // runs an end-to-end test of physical type cast
166219 // 1. construct a record batch with a column "a" of type A
@@ -564,6 +617,51 @@ mod tests {
564617 result. expect_err ( "expected Invalid TRY_CAST" ) ;
565618 }
566619
620+ #[ test]
621+ fn field_aware_same_type_try_cast_preserves_explicit_target_field ( ) -> Result < ( ) > {
622+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int32 , false ) ] ) ;
623+ let a_col = col ( "a" , & schema) ?;
624+ let logical_field = Arc :: new (
625+ Field :: new ( "logical_a" , DataType :: Int32 , true ) . with_metadata ( HashMap :: from (
626+ [ ( "target_meta" . to_string ( ) , "1" . to_string ( ) ) ] ,
627+ ) ) ,
628+ ) ;
629+ let expr = try_cast_with_target_field ( a_col, & schema, logical_field) ?;
630+
631+ let try_cast_expr = expr
632+ . as_any ( )
633+ . downcast_ref :: < TryCastExpr > ( )
634+ . expect ( "explicit same-type target should preserve TryCastExpr" ) ;
635+ let field = try_cast_expr. return_field ( & schema) ?;
636+
637+ assert_eq ! ( field. name( ) , "logical_a" ) ;
638+ assert ! ( field. is_nullable( ) ) ;
639+ assert_eq ! (
640+ field. metadata( ) . get( "target_meta" ) . map( String :: as_str) ,
641+ Some ( "1" )
642+ ) ;
643+ assert ! ( expr. nullable( & schema) ?) ;
644+
645+ Ok ( ( ) )
646+ }
647+
648+ #[ test]
649+ fn default_same_type_try_cast_is_elided ( ) -> Result < ( ) > {
650+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int32 , false ) ] ) ;
651+ let a_col = col ( "a" , & schema) ?;
652+ let target_field = DataType :: Int32 . into_nullable_field_ref ( ) ;
653+ let expr = try_cast_with_target_field ( a_col, & schema, target_field) ?;
654+
655+ assert ! (
656+ expr. as_any( )
657+ . downcast_ref:: <crate :: expressions:: Column >( )
658+ . is_some( )
659+ ) ;
660+ assert ! ( expr. as_any( ) . downcast_ref:: <TryCastExpr >( ) . is_none( ) ) ;
661+
662+ Ok ( ( ) )
663+ }
664+
567665 // create decimal array with the specified precision and scale
568666 fn create_decimal_array ( array : & [ i128 ] , precision : u8 , scale : i8 ) -> Decimal128Array {
569667 let mut decimal_builder = Decimal128Builder :: with_capacity ( array. len ( ) ) ;
0 commit comments