Skip to content

Commit 9d2790d

Browse files
authored
Merge branch 'main' into rm-aggregates-integers
2 parents: 2fd9259 + abf8f61 — commit 9d2790d

10 files changed

Lines changed: 663 additions & 183 deletions

File tree

datafusion/common/src/nested_struct.rs

Lines changed: 376 additions & 50 deletions
Large diffs are not rendered by default.

datafusion/common/src/scalar/mod.rs

Lines changed: 10 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -4138,20 +4138,16 @@ impl ScalarValue {
41384138

41394139
let scalar_array = self.to_array()?;
41404140

4141-
// For struct types, use name-based casting logic that matches fields by name
4142-
// and recursively casts nested structs. The field name wrapper is arbitrary
4143-
// since cast_column only uses the DataType::Struct field definitions inside.
4144-
let cast_arr = match target_type {
4145-
DataType::Struct(_) => {
4146-
// Field name is unused; only the struct's inner field names matter
4147-
let target_field = Field::new("_", target_type.clone(), true);
4148-
crate::nested_struct::cast_column(
4149-
&scalar_array,
4150-
&target_field,
4151-
cast_options,
4152-
)?
4153-
}
4154-
_ => cast_with_options(&scalar_array, target_type, cast_options)?,
4141+
// For types that contain structs (including nested inside Lists, Dictionaries,
4142+
// etc.), use name-based casting logic that matches struct fields by name and
4143+
// recursively casts nested structs.
4144+
let cast_arr = if crate::nested_struct::requires_nested_struct_cast(
4145+
scalar_array.data_type(),
4146+
target_type,
4147+
) {
4148+
crate::nested_struct::cast_column(&scalar_array, target_type, cast_options)?
4149+
} else {
4150+
cast_with_options(&scalar_array, target_type, cast_options)?
41554151
};
41564152

41574153
ScalarValue::try_from_array(&cast_arr, 0)

datafusion/expr-common/src/columnar_value.rs

Lines changed: 13 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
use arrow::{
2121
array::{Array, ArrayRef, Date32Array, Date64Array, NullArray},
2222
compute::{CastOptions, kernels, max, min},
23-
datatypes::{DataType, Field},
23+
datatypes::DataType,
2424
util::pretty::pretty_format_columns,
2525
};
2626
use datafusion_common::internal_datafusion_err;
@@ -313,24 +313,18 @@ fn cast_array_by_name(
313313
return Ok(Arc::clone(array));
314314
}
315315

316-
match cast_type {
317-
DataType::Struct(_) => {
318-
// Field name is unused; only the struct's inner field names matter
319-
let target_field = Field::new("_", cast_type.clone(), true);
320-
datafusion_common::nested_struct::cast_column(
321-
array,
322-
&target_field,
323-
cast_options,
324-
)
325-
}
326-
_ => {
327-
ensure_date_array_timestamp_bounds(array, cast_type)?;
328-
Ok(kernels::cast::cast_with_options(
329-
array,
330-
cast_type,
331-
cast_options,
332-
)?)
333-
}
316+
if datafusion_common::nested_struct::requires_nested_struct_cast(
317+
array.data_type(),
318+
cast_type,
319+
) {
320+
datafusion_common::nested_struct::cast_column(array, cast_type, cast_options)
321+
} else {
322+
ensure_date_array_timestamp_bounds(array, cast_type)?;
323+
Ok(kernels::cast::cast_with_options(
324+
array,
325+
cast_type,
326+
cast_options,
327+
)?)
334328
}
335329
}
336330

datafusion/functions-nested/src/remove.rs

Lines changed: 28 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -40,13 +40,13 @@ make_udf_expr_and_func!(
4040
ArrayRemove,
4141
array_remove,
4242
array element,
43-
"removes the first element from the array equal to the given value.",
43+
"removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
4444
array_remove_udf
4545
);
4646

4747
#[user_doc(
4848
doc_section(label = "Array Functions"),
49-
description = "Removes the first element from the array equal to the given value.",
49+
description = "Removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
5050
syntax_example = "array_remove(array, element)",
5151
sql_example = r#"```sql
5252
> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
@@ -55,6 +55,13 @@ make_udf_expr_and_func!(
5555
+----------------------------------------------+
5656
| [1, 2, 3, 2, 1, 4] |
5757
+----------------------------------------------+
58+
59+
> select array_remove([1, 2, NULL, 2, 4], 2);
60+
+---------------------------------------------------+
61+
| array_remove(List([1,2,NULL,2,4]),Int64(2)) |
62+
+---------------------------------------------------+
63+
| [1, NULL, 2, 4] |
64+
+---------------------------------------------------+
5865
```"#,
5966
argument(
6067
name = "array",
@@ -127,21 +134,28 @@ make_udf_expr_and_func!(
127134
ArrayRemoveN,
128135
array_remove_n,
129136
array element max,
130-
"removes the first `max` elements from the array equal to the given value.",
137+
"removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
131138
array_remove_n_udf
132139
);
133140

134141
#[user_doc(
135142
doc_section(label = "Array Functions"),
136-
description = "Removes the first `max` elements from the array equal to the given value.",
137-
syntax_example = "array_remove_n(array, element, max))",
143+
description = "Removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
144+
syntax_example = "array_remove_n(array, element, max)",
138145
sql_example = r#"```sql
139146
> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
140147
+---------------------------------------------------------+
141148
| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
142149
+---------------------------------------------------------+
143150
| [1, 3, 2, 1, 4] |
144151
+---------------------------------------------------------+
152+
153+
> select array_remove_n([1, 2, NULL, 2, 4], 2, 2);
154+
+----------------------------------------------------------+
155+
| array_remove_n(List([1,2,NULL,2,4]),Int64(2),Int64(2)) |
156+
+----------------------------------------------------------+
157+
| [1, NULL, 4] |
158+
+----------------------------------------------------------+
145159
```"#,
146160
argument(
147161
name = "array",
@@ -219,13 +233,13 @@ make_udf_expr_and_func!(
219233
ArrayRemoveAll,
220234
array_remove_all,
221235
array element,
222-
"removes all elements from the array equal to the given value.",
236+
"removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
223237
array_remove_all_udf
224238
);
225239

226240
#[user_doc(
227241
doc_section(label = "Array Functions"),
228-
description = "Removes all elements from the array equal to the given value.",
242+
description = "Removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
229243
syntax_example = "array_remove_all(array, element)",
230244
sql_example = r#"```sql
231245
> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
@@ -234,6 +248,13 @@ make_udf_expr_and_func!(
234248
+--------------------------------------------------+
235249
| [1, 3, 1, 4] |
236250
+--------------------------------------------------+
251+
252+
> select array_remove_all([1, 2, NULL, 2, 4], 2);
253+
+-----------------------------------------------------+
254+
| array_remove_all(List([1,2,NULL,2,4]),Int64(2)) |
255+
+-----------------------------------------------------+
256+
| [1, NULL, 4] |
257+
+-----------------------------------------------------+
237258
```"#,
238259
argument(
239260
name = "array",

datafusion/physical-expr-adapter/src/schema_rewriter.rs

Lines changed: 149 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -25,12 +25,11 @@ use std::hash::Hash;
2525
use std::sync::Arc;
2626

2727
use arrow::array::RecordBatch;
28-
use arrow::compute::can_cast_types;
2928
use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef};
3029
use datafusion_common::{
31-
Result, ScalarValue, exec_err,
30+
DataFusionError, Result, ScalarValue, exec_err,
3231
metadata::FieldMetadata,
33-
nested_struct::validate_struct_compatibility,
32+
nested_struct::validate_data_type_compatibility,
3433
tree_node::{Transformed, TransformedResult, TreeNode},
3534
};
3635
use datafusion_functions::core::getfield::GetFieldFunc;
@@ -487,31 +486,18 @@ impl DefaultPhysicalExprAdapterRewriter {
487486
physical_field: FieldRef,
488487
logical_field: &Field,
489488
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
490-
// For struct types, use validate_struct_compatibility which handles:
491-
// - Missing fields in source (filled with nulls)
492-
// - Extra fields in source (ignored)
493-
// - Recursive validation of nested structs
494-
// For non-struct types, use Arrow's can_cast_types
495-
match (physical_field.data_type(), logical_field.data_type()) {
496-
(DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => {
497-
validate_struct_compatibility(
498-
physical_fields.as_ref(),
499-
logical_fields.as_ref(),
500-
)?;
501-
}
502-
_ => {
503-
let is_compatible =
504-
can_cast_types(physical_field.data_type(), logical_field.data_type());
505-
if !is_compatible {
506-
return exec_err!(
507-
"Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
489+
validate_data_type_compatibility(
490+
column.name(),
491+
physical_field.data_type(),
492+
logical_field.data_type(),
493+
)
494+
.map_err(|e|
495+
DataFusionError::Execution(format!(
496+
"Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type): {e}",
508497
column.name(),
509498
physical_field.data_type(),
510499
logical_field.data_type()
511-
);
512-
}
513-
}
514-
}
500+
)))?;
515501

516502
let cast_expr = Arc::new(CastColumnExpr::new(
517503
Arc::new(column),
@@ -663,8 +649,8 @@ impl BatchAdapter {
663649
mod tests {
664650
use super::*;
665651
use arrow::array::{
666-
BooleanArray, Int32Array, Int64Array, RecordBatch, RecordBatchOptions,
667-
StringArray, StringViewArray, StructArray,
652+
Array, BooleanArray, GenericListArray, Int32Array, Int64Array, RecordBatch,
653+
RecordBatchOptions, StringArray, StringViewArray, StructArray,
668654
};
669655
use arrow::datatypes::{Fields, Schema};
670656
use datafusion_common::{assert_contains, record_batch};
@@ -1282,6 +1268,142 @@ mod tests {
12821268
assert_eq!(extra_values.iter().collect_vec(), vec![None, None, None]);
12831269
}
12841270

1271+
/// Test that List<Struct> columns are properly adapted with struct evolution.
1272+
#[test]
1273+
fn test_adapt_list_struct_batches() {
1274+
// Physical: List<{id: Int32, name: Utf8}>
1275+
let physical_struct_fields: Fields = vec![
1276+
Field::new("id", DataType::Int32, false),
1277+
Field::new("name", DataType::Utf8, true),
1278+
]
1279+
.into();
1280+
1281+
let struct_array = StructArray::new(
1282+
physical_struct_fields.clone(),
1283+
vec![
1284+
Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
1285+
Arc::new(StringArray::from(vec![
1286+
Some("alice"),
1287+
None,
1288+
Some("charlie"),
1289+
])) as _,
1290+
],
1291+
None,
1292+
);
1293+
1294+
// One list element per row
1295+
let item_field = Arc::new(Field::new(
1296+
"item",
1297+
DataType::Struct(physical_struct_fields.clone()),
1298+
true,
1299+
));
1300+
let offsets =
1301+
arrow::buffer::OffsetBuffer::from_lengths(vec![1usize; struct_array.len()]);
1302+
let list_array = GenericListArray::<i32>::new(
1303+
item_field,
1304+
offsets,
1305+
Arc::new(struct_array),
1306+
None,
1307+
);
1308+
1309+
let physical_schema = Arc::new(Schema::new(vec![Field::new(
1310+
"data",
1311+
DataType::List(Arc::new(Field::new(
1312+
"item",
1313+
DataType::Struct(physical_struct_fields),
1314+
true,
1315+
))),
1316+
false,
1317+
)]));
1318+
1319+
let physical_batch = RecordBatch::try_new(
1320+
Arc::clone(&physical_schema),
1321+
vec![Arc::new(list_array)],
1322+
)
1323+
.unwrap();
1324+
1325+
// Logical: List<{id: Int64, name: Utf8View, extra: Boolean}>
1326+
let logical_struct_fields: Fields = vec![
1327+
Field::new("id", DataType::Int64, false),
1328+
Field::new("name", DataType::Utf8View, true),
1329+
Field::new("extra", DataType::Boolean, true),
1330+
]
1331+
.into();
1332+
1333+
let logical_schema = Arc::new(Schema::new(vec![Field::new(
1334+
"data",
1335+
DataType::List(Arc::new(Field::new(
1336+
"item",
1337+
DataType::Struct(logical_struct_fields.clone()),
1338+
true,
1339+
))),
1340+
false,
1341+
)]));
1342+
1343+
let projection = vec![col("data", &logical_schema).unwrap()];
1344+
1345+
let factory = DefaultPhysicalExprAdapterFactory;
1346+
let adapter = factory
1347+
.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema))
1348+
.unwrap();
1349+
1350+
let adapted_projection = projection
1351+
.into_iter()
1352+
.map(|expr| adapter.rewrite(expr).unwrap())
1353+
.collect_vec();
1354+
1355+
let adapted_schema = Arc::new(Schema::new(
1356+
adapted_projection
1357+
.iter()
1358+
.map(|expr| expr.return_field(&physical_schema).unwrap())
1359+
.collect_vec(),
1360+
));
1361+
1362+
let res = batch_project(
1363+
adapted_projection,
1364+
&physical_batch,
1365+
Arc::clone(&adapted_schema),
1366+
)
1367+
.unwrap();
1368+
1369+
assert_eq!(res.num_columns(), 1);
1370+
1371+
let result_list = res
1372+
.column(0)
1373+
.as_any()
1374+
.downcast_ref::<GenericListArray<i32>>()
1375+
.unwrap();
1376+
1377+
// Check each list element contains the evolved struct
1378+
assert_eq!(result_list.len(), 3);
1379+
let flat_structs = result_list
1380+
.values()
1381+
.as_any()
1382+
.downcast_ref::<StructArray>()
1383+
.unwrap();
1384+
1385+
let id_col = flat_structs.column_by_name("id").unwrap();
1386+
assert_eq!(id_col.data_type(), &DataType::Int64);
1387+
let id_values = id_col.as_any().downcast_ref::<Int64Array>().unwrap();
1388+
assert_eq!(
1389+
id_values.iter().collect_vec(),
1390+
vec![Some(1), Some(2), Some(3)]
1391+
);
1392+
1393+
let name_col = flat_structs.column_by_name("name").unwrap();
1394+
assert_eq!(name_col.data_type(), &DataType::Utf8View);
1395+
let name_values = name_col.as_any().downcast_ref::<StringViewArray>().unwrap();
1396+
assert_eq!(
1397+
name_values.iter().collect_vec(),
1398+
vec![Some("alice"), None, Some("charlie")]
1399+
);
1400+
1401+
let extra_col = flat_structs.column_by_name("extra").unwrap();
1402+
assert_eq!(extra_col.data_type(), &DataType::Boolean);
1403+
let extra_values = extra_col.as_any().downcast_ref::<BooleanArray>().unwrap();
1404+
assert_eq!(extra_values.iter().collect_vec(), vec![None, None, None]);
1405+
}
1406+
12851407
#[test]
12861408
fn test_try_rewrite_struct_field_access() {
12871409
// Test the core logic of try_rewrite_struct_field_access

0 commit comments

Comments (0)