Skip to content

Commit b81b419

Browse files
committed
Implement improved handling of same-type casts
Refine cast handling in both cast.rs and schema_rewriter.rs. Explicit same-type casts now preserve CastExpr semantics, while default type-only casts are elided. Update planner tests to properly distinguish between the two and ensure consistent unified behavior across adapters. Added low-level tests for preserved and elided same-type cases for better coverage.
1 parent a878670 commit b81b419

3 files changed

Lines changed: 116 additions & 47 deletions

File tree

datafusion/physical-expr-adapter/src/schema_rewriter.rs

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ use datafusion_common::{
3535
};
3636
use datafusion_functions::core::getfield::GetFieldFunc;
3737
use datafusion_physical_expr::PhysicalExprSimplifier;
38-
use datafusion_physical_expr::expressions::CastColumnExpr;
3938
use datafusion_physical_expr::projection::{ProjectionExprs, Projector};
4039
use datafusion_physical_expr::{
4140
ScalarFunctionExpr,
42-
expressions::{self, Column},
41+
expressions::{self, CastExpr, Column},
4342
};
4443
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
4544
use itertools::Itertools;
@@ -440,7 +439,7 @@ impl DefaultPhysicalExprAdapterRewriter {
440439
// TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
441440
// since that's much cheaper to evalaute.
442441
// See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928
443-
self.create_cast_column_expr(resolved_column, physical_field, logical_field)
442+
self.create_cast_expr(resolved_column, physical_field, logical_field)
444443
}
445444

446445
/// Resolves a logical column to the corresponding physical column and field.
@@ -476,12 +475,12 @@ impl DefaultPhysicalExprAdapterRewriter {
476475
)))
477476
}
478477

479-
/// Validates type compatibility and creates a CastColumnExpr if needed.
478+
/// Validates type compatibility and creates a field-aware CastExpr if needed.
480479
///
481480
/// Checks whether the physical field can be cast to the logical field type,
482-
/// handling both struct and scalar types. Returns a CastColumnExpr with the
483-
/// appropriate configuration.
484-
fn create_cast_column_expr(
481+
/// handling both struct and scalar types. Returns a CastExpr with the
482+
/// appropriate logical target field configuration.
483+
fn create_cast_expr(
485484
&self,
486485
column: Column,
487486
physical_field: FieldRef,
@@ -513,9 +512,15 @@ impl DefaultPhysicalExprAdapterRewriter {
513512
}
514513
}
515514

516-
let cast_expr = Arc::new(CastColumnExpr::new(
515+
let physical_column_index = self.physical_file_schema.index_of(column.name())?;
516+
let column = if column.index() == physical_column_index {
517+
column
518+
} else {
519+
Column::new_with_schema(column.name(), self.physical_file_schema.as_ref())?
520+
};
521+
522+
let cast_expr = Arc::new(CastExpr::new_with_target_field(
517523
Arc::new(column),
518-
physical_field,
519524
Arc::new(logical_field.clone()),
520525
None,
521526
));
@@ -669,7 +674,7 @@ mod tests {
669674
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
670675
use datafusion_common::{Result, ScalarValue, assert_contains, record_batch};
671676
use datafusion_expr::Operator;
672-
use datafusion_physical_expr::expressions::{Column, Literal, col, lit};
677+
use datafusion_physical_expr::expressions::{CastExpr, Column, Literal, col, lit};
673678
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
674679
use itertools::Itertools;
675680
use std::sync::Arc;
@@ -702,7 +707,7 @@ mod tests {
702707
let result = adapter.rewrite(column_expr).unwrap();
703708

704709
// Should be wrapped in a cast expression
705-
assert!(result.as_any().downcast_ref::<CastColumnExpr>().is_some());
710+
assert!(result.as_any().downcast_ref::<CastExpr>().is_some());
706711
}
707712

708713
#[test]
@@ -723,8 +728,8 @@ mod tests {
723728
let result = adapter.rewrite(Arc::new(Column::new("a", 0)))?;
724729
let cast = result
725730
.as_any()
726-
.downcast_ref::<CastColumnExpr>()
727-
.expect("Expected CastColumnExpr");
731+
.downcast_ref::<CastExpr>()
732+
.expect("Expected CastExpr");
728733

729734
assert_eq!(cast.target_field().data_type(), &DataType::Int64);
730735
assert!(!cast.target_field().is_nullable());
@@ -736,8 +741,10 @@ mod tests {
736741
Some("1")
737742
);
738743

739-
// Ensure the expression reports the logical nullability regardless of input schema
740-
assert!(!result.nullable(physical_schema.as_ref())?);
744+
// Runtime nullability follows the child expression, but the logical
745+
// target field nullability is still preserved via return_field().
746+
assert!(result.nullable(physical_schema.as_ref())?);
747+
assert!(!result.return_field(physical_schema.as_ref())?.is_nullable());
741748

742749
Ok(())
743750
}
@@ -772,9 +779,8 @@ mod tests {
772779
println!("Rewritten expression: {result}");
773780

774781
let expected = expressions::BinaryExpr::new(
775-
Arc::new(CastColumnExpr::new(
782+
Arc::new(CastExpr::new_with_target_field(
776783
Arc::new(Column::new("a", 0)),
777-
Arc::new(Field::new("a", DataType::Int32, false)),
778784
Arc::new(Field::new("a", DataType::Int64, false)),
779785
None,
780786
)),
@@ -860,17 +866,6 @@ mod tests {
860866

861867
let result = adapter.rewrite(column_expr).unwrap();
862868

863-
let physical_struct_fields: Fields = vec![
864-
Field::new("id", DataType::Int32, false),
865-
Field::new("name", DataType::Utf8, true),
866-
]
867-
.into();
868-
let physical_field = Arc::new(Field::new(
869-
"data",
870-
DataType::Struct(physical_struct_fields),
871-
false,
872-
));
873-
874869
let logical_struct_fields: Fields = vec![
875870
Field::new("id", DataType::Int64, false),
876871
Field::new("name", DataType::Utf8View, true),
@@ -882,9 +877,8 @@ mod tests {
882877
false,
883878
));
884879

885-
let expected = Arc::new(CastColumnExpr::new(
880+
let expected = Arc::new(CastExpr::new_with_target_field(
886881
Arc::new(Column::new("data", 0)),
887-
physical_field,
888882
logical_field,
889883
None,
890884
)) as Arc<dyn PhysicalExpr>;
@@ -1558,11 +1552,11 @@ mod tests {
15581552

15591553
let result = adapter.rewrite(column_expr).unwrap();
15601554

1561-
// Should be a CastColumnExpr
1555+
// Should be a CastExpr
15621556
let cast_expr = result
15631557
.as_any()
1564-
.downcast_ref::<CastColumnExpr>()
1565-
.expect("Expected CastColumnExpr");
1558+
.downcast_ref::<CastExpr>()
1559+
.expect("Expected CastExpr");
15661560

15671561
// Verify the inner column points to the correct physical index (1)
15681562
let inner_col = cast_expr
@@ -1581,7 +1575,7 @@ mod tests {
15811575
}
15821576

15831577
#[test]
1584-
fn test_create_cast_column_expr_uses_name_lookup_not_column_index() {
1578+
fn test_create_cast_expr_uses_name_lookup_not_column_index() {
15851579
// Physical schema has column `a` at index 1; index 0 is an incompatible type.
15861580
let physical_schema = Arc::new(Schema::new(vec![
15871581
Field::new("b", DataType::Binary, true),
@@ -1601,7 +1595,7 @@ mod tests {
16011595
// Deliberately provide the wrong index for column `a`.
16021596
// Regression: this must still resolve against physical field `a` by name.
16031597
let transformed = rewriter
1604-
.create_cast_column_expr(
1598+
.create_cast_expr(
16051599
Column::new("a", 0),
16061600
Arc::new(physical_schema.field_with_name("a").unwrap().clone()),
16071601
logical_schema.field_with_name("a").unwrap(),
@@ -1611,11 +1605,16 @@ mod tests {
16111605
let cast_expr = transformed
16121606
.data
16131607
.as_any()
1614-
.downcast_ref::<CastColumnExpr>()
1615-
.expect("Expected CastColumnExpr");
1608+
.downcast_ref::<CastExpr>()
1609+
.expect("Expected CastExpr");
16161610

1617-
assert_eq!(cast_expr.input_field().name(), "a");
1618-
assert_eq!(cast_expr.input_field().data_type(), &DataType::Int32);
1611+
let inner_col = cast_expr
1612+
.expr()
1613+
.as_any()
1614+
.downcast_ref::<Column>()
1615+
.expect("Expected inner Column");
1616+
assert_eq!(inner_col.name(), "a");
1617+
assert_eq!(inner_col.index(), 1);
16191618
assert_eq!(cast_expr.target_field().data_type(), &DataType::Int64);
16201619
}
16211620
}

datafusion/physical-expr/src/expressions/cast.rs

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ impl CastExpr {
176176
}
177177
}
178178

179+
fn preserves_child_field_semantics(&self, input_schema: &Schema) -> Result<bool> {
180+
Ok(self.resolved_target_field(input_schema)? == self.expr.return_field(input_schema)?)
181+
}
182+
179183
/// Check if casting from the specified source type to the target type is a
180184
/// widening cast (e.g. from `Int8` to `Int16`).
181185
pub fn check_bigger_cast(cast_type: &DataType, src: &DataType) -> bool {
@@ -337,7 +341,13 @@ pub(crate) fn cast_with_target_field_and_options(
337341
) -> Result<Arc<dyn PhysicalExpr>> {
338342
let expr_type = expr.data_type(input_schema)?;
339343
let cast_type = target_field.data_type();
340-
if expr_type == *cast_type {
344+
let candidate = CastExpr::new_with_target_field(
345+
Arc::clone(&expr),
346+
Arc::clone(&target_field),
347+
cast_options.clone(),
348+
);
349+
350+
if expr_type == *cast_type && candidate.preserves_child_field_semantics(input_schema)? {
341351
return Ok(Arc::clone(&expr));
342352
}
343353

@@ -353,11 +363,7 @@ pub(crate) fn cast_with_target_field_and_options(
353363
if !is_valid_cast {
354364
not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}")
355365
} else {
356-
Ok(Arc::new(CastExpr::new_with_target_field(
357-
expr,
358-
target_field,
359-
cast_options,
360-
)))
366+
Ok(Arc::new(candidate))
361367
}
362368
}
363369

@@ -377,7 +383,7 @@ pub fn cast(
377383
mod tests {
378384
use super::*;
379385

380-
use crate::expressions::column::col;
386+
use crate::expressions::{Column, column::col};
381387

382388
use arrow::{
383389
array::{
@@ -921,6 +927,51 @@ mod tests {
921927
Ok(())
922928
}
923929

930+
#[test]
931+
fn field_aware_same_type_cast_preserves_explicit_target_field() -> Result<()> {
932+
let schema = Schema::new(vec![Field::new("a", Int32, false)]);
933+
let expr = cast_with_target_field_and_options(
934+
col("a", &schema)?,
935+
&schema,
936+
Arc::new(Field::new("logical_a", Int32, true).with_metadata(HashMap::from([(
937+
"target_meta".to_string(),
938+
"1".to_string(),
939+
)]))),
940+
None,
941+
)?;
942+
943+
let cast_expr = expr
944+
.as_any()
945+
.downcast_ref::<CastExpr>()
946+
.expect("explicit same-type target should preserve CastExpr");
947+
let field = cast_expr.return_field(&schema)?;
948+
949+
assert_eq!(field.name(), "logical_a");
950+
assert!(field.is_nullable());
951+
assert_eq!(
952+
field.metadata().get("target_meta").map(String::as_str),
953+
Some("1")
954+
);
955+
956+
Ok(())
957+
}
958+
959+
#[test]
960+
fn default_same_type_cast_is_elided() -> Result<()> {
961+
let schema = Schema::new(vec![Field::new("a", Int32, false)]);
962+
let expr = cast_with_target_field_and_options(
963+
col("a", &schema)?,
964+
&schema,
965+
Int32.into_nullable_field_ref(),
966+
None,
967+
)?;
968+
969+
assert!(expr.as_any().downcast_ref::<Column>().is_some());
970+
assert!(expr.as_any().downcast_ref::<CastExpr>().is_none());
971+
972+
Ok(())
973+
}
974+
924975
#[test]
925976
fn field_aware_cast_nullable_prefers_child_nullability() -> Result<()> {
926977
// When the child expression is nullable the cast must be treated as

datafusion/physical-expr/src/planner.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ mod tests {
542542
}
543543

544544
#[test]
545-
fn test_create_physical_expr_same_type_cast_is_elided() -> Result<()> {
545+
fn test_create_physical_expr_same_type_cast_preserves_explicit_target_field() -> Result<()> {
546546
let schema = Schema::new(vec![Field::new("value", DataType::Int32, false)]);
547547
let df_schema = DFSchema::try_from(schema.clone())?;
548548
let target_field = Arc::new(
@@ -555,6 +555,25 @@ mod tests {
555555
Arc::clone(&target_field),
556556
));
557557

558+
let physical_expr =
559+
create_physical_expr(&expr, &df_schema, &ExecutionProps::new())?;
560+
561+
let cast_expr = physical_expr
562+
.as_any()
563+
.downcast_ref::<CastExpr>()
564+
.expect("same-type cast with explicit field should preserve CastExpr");
565+
assert_eq!(cast_expr.target_field(), &target_field);
566+
assert_eq!(physical_expr.return_field(&schema)?.name(), "same_type_cast");
567+
568+
Ok(())
569+
}
570+
571+
#[test]
572+
fn test_create_physical_expr_default_same_type_cast_is_elided() -> Result<()> {
573+
let schema = Schema::new(vec![Field::new("value", DataType::Int32, false)]);
574+
let df_schema = DFSchema::try_from(schema.clone())?;
575+
let expr = Expr::Cast(Cast::new(Box::new(col("value")), DataType::Int32));
576+
558577
let physical_expr =
559578
create_physical_expr(&expr, &df_schema, &ExecutionProps::new())?;
560579

0 commit comments

Comments
 (0)