Skip to content

Commit 95fc886

Browse files
committed
Improve cast lowering to preserve field metadata
Update cast handling in planner.rs to retain the logical target FieldRef metadata during cast lowering. Introduce a new field-aware helper in cast.rs for this purpose, which ensures that metadata, name, and nullability intent are kept intact, even for same-type casts.

Update planner tests to verify that:
- Metadata and nullability are preserved in lowered casts
- Same-type casts with different field semantics still produce a CastExpr
- Standard non-metadata casts continue to work (regression check)
- CAST retains extension metadata, while TRY_CAST still rejects it
1 parent fd97799 commit 95fc886

3 files changed

Lines changed: 141 additions & 37 deletions

File tree

datafusion/physical-expr/src/expressions/cast.rs

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,22 +319,51 @@ pub fn cast_with_options(
319319
input_schema: &Schema,
320320
cast_type: DataType,
321321
cast_options: Option<CastOptions<'static>>,
322+
) -> Result<Arc<dyn PhysicalExpr>> {
323+
cast_with_target_field_and_options(
324+
expr,
325+
input_schema,
326+
cast_type.into_nullable_field_ref(),
327+
cast_options,
328+
)
329+
}
330+
331+
/// Return a [`PhysicalExpr`] representing `expr` casted to `target_field`,
332+
/// preserving the field metadata in the resulting expression.
333+
pub fn cast_with_target_field_and_options(
334+
expr: Arc<dyn PhysicalExpr>,
335+
input_schema: &Schema,
336+
target_field: FieldRef,
337+
cast_options: Option<CastOptions<'static>>,
322338
) -> Result<Arc<dyn PhysicalExpr>> {
323339
let expr_type = expr.data_type(input_schema)?;
340+
let cast_type = target_field.data_type().clone();
324341
if expr_type == cast_type {
325-
Ok(Arc::clone(&expr))
342+
Ok(Arc::new(CastExpr::new_with_target_field(
343+
expr,
344+
target_field,
345+
cast_options,
346+
)))
326347
} else if matches!((&expr_type, &cast_type), (Struct(_), Struct(_))) {
327348
if can_cast_struct_types(&expr_type, &cast_type) {
328349
// Allow struct-to-struct casts that pass name-based compatibility validation.
329350
// This validation is applied at planning time (now) to fail fast, rather than
330351
// deferring errors to execution time. The name-based casting logic will be
331352
// executed at runtime via ColumnarValue::cast_to.
332-
Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
353+
Ok(Arc::new(CastExpr::new_with_target_field(
354+
expr,
355+
target_field,
356+
cast_options,
357+
)))
333358
} else {
334359
not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}")
335360
}
336361
} else if can_cast_types(&expr_type, &cast_type) {
337-
Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
362+
Ok(Arc::new(CastExpr::new_with_target_field(
363+
expr,
364+
target_field,
365+
cast_options,
366+
)))
338367
} else {
339368
not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}")
340369
}

datafusion/physical-expr/src/expressions/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ pub use crate::aggregate::stats::StatsType;
4141

4242
pub use binary::{BinaryExpr, binary, similar_to};
4343
pub use case::{CaseExpr, case};
44-
pub use cast::{CastExpr, cast};
44+
pub use cast::{CastExpr, cast, cast_with_target_field_and_options};
4545
pub use cast_column::CastColumnExpr;
4646
pub use column::{Column, col, with_new_schema};
4747
pub use datafusion_expr::utils::format_state_name;

datafusion/physical-expr/src/planner.rs

Lines changed: 108 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -289,22 +289,11 @@ pub fn create_physical_expr(
289289
Ok(expressions::case(expr, when_then_expr, else_expr)?)
290290
}
291291
Expr::Cast(Cast { expr, field }) => {
292-
if !field.metadata().is_empty() {
293-
let (_, src_field) = expr.to_field(input_dfschema)?;
294-
return plan_err!(
295-
"Cast from {} to {} is not supported",
296-
format_type_and_metadata(
297-
src_field.data_type(),
298-
Some(src_field.metadata()),
299-
),
300-
format_type_and_metadata(field.data_type(), Some(field.metadata()))
301-
);
302-
}
303-
304-
expressions::cast(
292+
expressions::cast_with_target_field_and_options(
305293
create_physical_expr(expr, input_dfschema, execution_props)?,
306294
input_schema,
307-
field.data_type().clone(),
295+
Arc::clone(field),
296+
None,
308297
)
309298
}
310299
Expr::TryCast(TryCast { expr, field }) => {
@@ -443,10 +432,12 @@ pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc<dyn PhysicalExpr> {
443432

444433
#[cfg(test)]
445434
mod tests {
435+
use crate::expressions::{CastExpr, Column};
446436
use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
447437
use arrow::datatypes::{DataType, Field};
448438
use datafusion_common::datatype::DataTypeExt;
449439
use datafusion_expr::{Operator, col, lit};
440+
use std::collections::HashMap;
450441

451442
use super::*;
452443

@@ -477,39 +468,123 @@ mod tests {
477468

478469
#[test]
479470
fn test_cast_to_extension_type() -> Result<()> {
480-
let extension_field_type = Arc::new(
481-
DataType::FixedSizeBinary(16)
482-
.into_nullable_field()
483-
.with_metadata(
484-
[("ARROW:extension:name".to_string(), "arrow.uuid".to_string())]
485-
.into(),
486-
),
487-
);
488-
let expr = lit("3230e5d4-888e-408b-b09b-831f44aa0c58");
471+
let extension_field_type =
472+
Arc::new(DataType::Int64.into_nullable_field().with_metadata(
473+
[("ARROW:extension:name".to_string(), "arrow.uuid".to_string())].into(),
474+
));
475+
let expr = col("value");
476+
let schema = Schema::new(vec![Field::new("value", DataType::Int32, false)]);
477+
let df_schema = DFSchema::try_from(schema.clone())?;
489478
let cast_expr = Expr::Cast(Cast::new_from_field(
490479
Box::new(expr.clone()),
491480
Arc::clone(&extension_field_type),
492481
));
493-
let err =
494-
create_physical_expr(&cast_expr, &DFSchema::empty(), &ExecutionProps::new())
495-
.unwrap_err();
496-
assert!(err.message().contains("arrow.uuid"));
482+
let physical_expr =
483+
create_physical_expr(&cast_expr, &df_schema, &ExecutionProps::new())?;
484+
485+
let field = physical_expr.return_field(&schema)?;
486+
assert_eq!(field.data_type(), extension_field_type.data_type());
487+
assert_eq!(field.metadata(), extension_field_type.metadata());
488+
assert!(physical_expr.nullable(&schema)?);
497489

498490
let try_cast_expr = Expr::TryCast(TryCast::new_from_field(
499491
Box::new(expr.clone()),
500492
Arc::clone(&extension_field_type),
501493
));
502-
let err = create_physical_expr(
503-
&try_cast_expr,
504-
&DFSchema::empty(),
505-
&ExecutionProps::new(),
506-
)
507-
.unwrap_err();
494+
let err =
495+
create_physical_expr(&try_cast_expr, &df_schema, &ExecutionProps::new())
496+
.unwrap_err();
508497
assert!(err.message().contains("arrow.uuid"));
509498

510499
Ok(())
511500
}
512501

502+
#[test]
503+
fn test_create_physical_expr_cast_preserves_target_field_semantics() -> Result<()> {
504+
let schema = Schema::new(vec![Field::new("value", DataType::Int32, false)]);
505+
let df_schema = DFSchema::try_from(schema.clone())?;
506+
let target_field = Arc::new(
507+
Field::new("logical_cast", DataType::Int64, true).with_metadata(
508+
HashMap::from([("semantic".to_string(), "target".to_string())]),
509+
),
510+
);
511+
let expr = Expr::Cast(Cast::new_from_field(
512+
Box::new(col("value")),
513+
Arc::clone(&target_field),
514+
));
515+
516+
let physical_expr =
517+
create_physical_expr(&expr, &df_schema, &ExecutionProps::new())?;
518+
let cast_expr = physical_expr
519+
.as_any()
520+
.downcast_ref::<CastExpr>()
521+
.expect("logical cast should lower to CastExpr");
522+
523+
assert_eq!(cast_expr.target_field(), &target_field);
524+
525+
let field = physical_expr.return_field(&schema)?;
526+
assert_eq!(field.name(), "logical_cast");
527+
assert_eq!(field.data_type(), &DataType::Int64);
528+
assert!(field.is_nullable());
529+
assert_eq!(
530+
field.metadata().get("semantic").map(String::as_str),
531+
Some("target")
532+
);
533+
assert!(physical_expr.nullable(&schema)?);
534+
535+
Ok(())
536+
}
537+
538+
#[test]
539+
fn test_create_physical_expr_same_type_cast_keeps_target_field() -> Result<()> {
540+
let schema = Schema::new(vec![Field::new("value", DataType::Int32, false)]);
541+
let df_schema = DFSchema::try_from(schema.clone())?;
542+
let target_field = Arc::new(
543+
Field::new("same_type_cast", DataType::Int32, true).with_metadata(
544+
HashMap::from([("semantic".to_string(), "same_type".to_string())]),
545+
),
546+
);
547+
let expr = Expr::Cast(Cast::new_from_field(
548+
Box::new(col("value")),
549+
Arc::clone(&target_field),
550+
));
551+
552+
let physical_expr =
553+
create_physical_expr(&expr, &df_schema, &ExecutionProps::new())?;
554+
555+
assert!(physical_expr.as_any().downcast_ref::<Column>().is_none());
556+
assert!(physical_expr.as_any().downcast_ref::<CastExpr>().is_some());
557+
558+
let field = physical_expr.return_field(&schema)?;
559+
assert_eq!(field.name(), "same_type_cast");
560+
assert_eq!(field.data_type(), &DataType::Int32);
561+
assert!(field.is_nullable());
562+
assert_eq!(
563+
field.metadata().get("semantic").map(String::as_str),
564+
Some("same_type")
565+
);
566+
567+
Ok(())
568+
}
569+
570+
#[test]
571+
fn test_create_physical_expr_standard_cast_still_validates() -> Result<()> {
572+
let schema = Schema::new(vec![Field::new("value", DataType::Int32, false)]);
573+
let df_schema = DFSchema::try_from(schema.clone())?;
574+
let expr = Expr::Cast(Cast::new(Box::new(col("value")), DataType::Int64));
575+
576+
let physical_expr =
577+
create_physical_expr(&expr, &df_schema, &ExecutionProps::new())?;
578+
579+
assert_eq!(physical_expr.data_type(&schema)?, DataType::Int64);
580+
assert_eq!(
581+
physical_expr.return_field(&schema)?.data_type(),
582+
&DataType::Int64
583+
);
584+
585+
Ok(())
586+
}
587+
513588
/// Test that deeply nested expressions do not cause a stack overflow.
514589
///
515590
/// This test only runs when the `recursive_protection` feature is enabled,

0 commit comments

Comments
 (0)