Skip to content

Commit e8a93bb

Browse files
authored
feat: automatically cast ListView to List for UDFs (#21855)
## Which issue does this PR close?

<!--
We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123.
-->

- Part of #21777

## Rationale for this change

<!--
Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes.
-->

First step in supporting `ListView` types in UDFs. Previously, passing a `ListView` type to a nested UDF would result in an error. Now `ListView`/`LargeListView` inputs are accepted and cast to `List`/`LargeList` respectively (similar to how `FixedSizeList` is handled if that coercion mode is enabled).

## What changes are included in this PR?

<!--
There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR.
-->

During type coercion for array signatures, automatically coerce any `ListView`/`LargeListView` types to `List`/`LargeList` respectively. Also add `ListView` support to various array util functions.

## Are these changes tested?

<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example, are they covered by existing tests)?
-->

Added SLTs (sqllogictest cases).

## Are there any user-facing changes?

<!--
If there are user-facing changes then we may require documentation to be updated before approving the PR.
-->

No.

<!--
If there are any breaking changes to public APIs, please add the `api change` label.
-->
1 parent b9cf885 commit e8a93bb

8 files changed

Lines changed: 174 additions & 23 deletions

File tree

datafusion/common/src/utils/mod.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,11 +599,17 @@ pub fn base_type(data_type: &DataType) -> DataType {
599599
match data_type {
600600
DataType::List(field)
601601
| DataType::LargeList(field)
602+
| DataType::ListView(field)
603+
| DataType::LargeListView(field)
602604
| DataType::FixedSizeList(field, _) => base_type(field.data_type()),
603605
_ => data_type.to_owned(),
604606
}
605607
}
606608

609+
// TODO: Extend this to also allow specifying how ListViews should be treated:
610+
// for example, whether to cast them to List (the default) or to keep them as
611+
// ListView (which requires the function to implement support for ListViews).
612+
// https://github.com/apache/datafusion/issues/21777
607613
/// Information about how to coerce lists.
608614
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
609615
pub enum ListCoercion {
@@ -657,6 +663,19 @@ pub fn coerced_type_with_base_type_only(
657663
*len,
658664
)
659665
}
666+
(DataType::ListView(field), _) => {
667+
let field_type = coerced_type_with_base_type_only(
668+
field.data_type(),
669+
base_type,
670+
array_coercion,
671+
);
672+
673+
DataType::ListView(Arc::new(Field::new(
674+
field.name(),
675+
field_type,
676+
field.is_nullable(),
677+
)))
678+
}
660679
(DataType::LargeList(field), _) => {
661680
let field_type = coerced_type_with_base_type_only(
662681
field.data_type(),
@@ -670,6 +689,19 @@ pub fn coerced_type_with_base_type_only(
670689
field.is_nullable(),
671690
)))
672691
}
692+
(DataType::LargeListView(field), _) => {
693+
let field_type = coerced_type_with_base_type_only(
694+
field.data_type(),
695+
base_type,
696+
array_coercion,
697+
);
698+
699+
DataType::LargeListView(Arc::new(Field::new(
700+
field.name(),
701+
field_type,
702+
field.is_nullable(),
703+
)))
704+
}
673705

674706
_ => base_type.clone(),
675707
}
@@ -687,6 +719,15 @@ pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType {
687719
field.is_nullable(),
688720
)))
689721
}
722+
DataType::ListView(field) => {
723+
let field_type = coerced_fixed_size_list_to_list(field.data_type());
724+
725+
DataType::ListView(Arc::new(Field::new(
726+
field.name(),
727+
field_type,
728+
field.is_nullable(),
729+
)))
730+
}
690731
DataType::LargeList(field) => {
691732
let field_type = coerced_fixed_size_list_to_list(field.data_type());
692733

@@ -696,6 +737,15 @@ pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType {
696737
field.is_nullable(),
697738
)))
698739
}
740+
DataType::LargeListView(field) => {
741+
let field_type = coerced_fixed_size_list_to_list(field.data_type());
742+
743+
DataType::LargeListView(Arc::new(Field::new(
744+
field.name(),
745+
field_type,
746+
field.is_nullable(),
747+
)))
748+
}
699749

700750
_ => data_type.clone(),
701751
}
@@ -706,6 +756,8 @@ pub fn list_ndims(data_type: &DataType) -> u64 {
706756
match data_type {
707757
DataType::List(field)
708758
| DataType::LargeList(field)
759+
| DataType::ListView(field)
760+
| DataType::LargeListView(field)
709761
| DataType::FixedSizeList(field, _) => 1 + list_ndims(field.data_type()),
710762
_ => 0,
711763
}

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,12 +537,12 @@ fn get_valid_types(
537537
element_types.push(DataType::Null);
538538
nested_item_nullability.push(None);
539539
}
540-
DataType::List(field) => {
540+
DataType::List(field) | DataType::ListView(field) => {
541541
element_types.push(field.data_type().clone());
542542
nested_item_nullability.push(Some(field.is_nullable()));
543543
fixed_size = false;
544544
}
545-
DataType::LargeList(field) => {
545+
DataType::LargeList(field) | DataType::LargeListView(field) => {
546546
element_types.push(field.data_type().clone());
547547
nested_item_nullability.push(Some(field.is_nullable()));
548548
large_list = true;
@@ -580,6 +580,8 @@ fn get_valid_types(
580580
ArrayFunctionArgument::Index => DataType::Int64,
581581
ArrayFunctionArgument::String => DataType::Utf8,
582582
ArrayFunctionArgument::Element => element_type.clone(),
583+
// TODO: support maintaining ListView types here
584+
// https://github.com/apache/datafusion/issues/21777
583585
ArrayFunctionArgument::Array => {
584586
if current_type.is_null() {
585587
DataType::Null
@@ -611,6 +613,8 @@ fn get_valid_types(
611613
match array_type {
612614
DataType::List(_)
613615
| DataType::LargeList(_)
616+
| DataType::ListView(_)
617+
| DataType::LargeListView(_)
614618
| DataType::FixedSizeList(_, _) => {
615619
let array_type = coerced_fixed_size_list_to_list(array_type);
616620
Some(array_type)
@@ -1044,7 +1048,7 @@ fn coerced_from<'a>(
10441048

10451049
// Only accept List, LargeList, ListView and LargeListView with the same number of dimensions unless the type is Null.
10461050
// List-like types with different dimensions should be handled in TypeSignature or other places before this
1047-
(List(_) | LargeList(_), _)
1051+
(List(_) | LargeList(_) | ListView(_) | LargeListView(_), _)
10481052
if base_type(type_from).is_null()
10491053
|| list_ndims(type_from) == list_ndims(type_into) =>
10501054
{
@@ -1495,6 +1499,54 @@ mod tests {
14951499
]]
14961500
);
14971501

1502+
let data_types = vec![
1503+
DataType::ListView(Field::new_list_field(DataType::Int32, true).into()),
1504+
DataType::new_list(DataType::Int32, true),
1505+
];
1506+
assert_eq!(
1507+
get_valid_types(function, &signature.type_signature, &data_types)?,
1508+
vec![vec![
1509+
DataType::new_list(DataType::Int32, true),
1510+
DataType::new_list(DataType::Int32, true),
1511+
]]
1512+
);
1513+
1514+
let data_types = vec![
1515+
DataType::LargeListView(Field::new_list_field(DataType::Int32, true).into()),
1516+
DataType::new_list(DataType::Int32, true),
1517+
];
1518+
assert_eq!(
1519+
get_valid_types(function, &signature.type_signature, &data_types)?,
1520+
vec![vec![
1521+
DataType::new_large_list(DataType::Int32, true),
1522+
DataType::new_large_list(DataType::Int32, true),
1523+
]]
1524+
);
1525+
1526+
let data_types = vec![
1527+
DataType::ListView(Field::new_list_field(DataType::Int32, true).into()),
1528+
DataType::ListView(Field::new_list_field(DataType::Int32, true).into()),
1529+
];
1530+
assert_eq!(
1531+
get_valid_types(function, &signature.type_signature, &data_types)?,
1532+
vec![vec![
1533+
DataType::new_list(DataType::Int32, true),
1534+
DataType::new_list(DataType::Int32, true),
1535+
]]
1536+
);
1537+
1538+
let data_types = vec![
1539+
DataType::LargeListView(Field::new_list_field(DataType::Int32, true).into()),
1540+
DataType::LargeListView(Field::new_list_field(DataType::Int32, true).into()),
1541+
];
1542+
assert_eq!(
1543+
get_valid_types(function, &signature.type_signature, &data_types)?,
1544+
vec![vec![
1545+
DataType::new_large_list(DataType::Int32, true),
1546+
DataType::new_large_list(DataType::Int32, true),
1547+
]]
1548+
);
1549+
14981550
Ok(())
14991551
}
15001552

datafusion/functions-nested/src/utils.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ use arrow::array::{
2626
};
2727
use arrow::buffer::OffsetBuffer;
2828
use datafusion_common::cast::{
29-
as_fixed_size_list_array, as_large_list_array, as_list_array,
29+
as_fixed_size_list_array, as_large_list_array, as_large_list_view_array,
30+
as_list_array, as_list_view_array,
3031
};
3132
use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err};
3233

@@ -243,6 +244,14 @@ pub(crate) fn compute_array_dims(
243244
value = as_large_list_array(&value)?.value(0);
244245
res.push(Some(value.len() as u64));
245246
}
247+
DataType::ListView(_) => {
248+
value = as_list_view_array(&value)?.value(0);
249+
res.push(Some(value.len() as u64));
250+
}
251+
DataType::LargeListView(_) => {
252+
value = as_large_list_view_array(&value)?.value(0);
253+
res.push(Some(value.len() as u64));
254+
}
246255
DataType::FixedSizeList(..) => {
247256
value = as_fixed_size_list_array(&value)?.value(0);
248257
res.push(Some(value.len() as u64));

datafusion/sqllogictest/test_files/array/array_empty.slt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,5 +176,15 @@ select list_empty(arrow_cast(make_array(NULL), 'LargeList(Int64)'));
176176
----
177177
false
178178

179+
query B
180+
select empty(arrow_cast([1], 'ListView(Int64)'));
181+
----
182+
false
183+
184+
query B
185+
select empty(arrow_cast([1], 'LargeListView(Int64)'));
186+
----
187+
false
188+
179189

180190
include ./cleanup.slt.part

datafusion/sqllogictest/test_files/array/array_pop.slt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,5 +302,25 @@ select array_pop_front(arrow_cast(make_array(make_array(1, 2, 3)), 'FixedSizeLis
302302
----
303303
[]
304304

305+
query ?
306+
select array_pop_back(arrow_cast([1, 2], 'ListView(Int64)'));
307+
----
308+
[1]
309+
310+
query ?
311+
select array_pop_front(arrow_cast([1, 2], 'ListView(Int64)'));
312+
----
313+
[2]
314+
315+
query ?
316+
select array_pop_back(arrow_cast([1, 2], 'LargeListView(Int64)'));
317+
----
318+
[1]
319+
320+
query ?
321+
select array_pop_front(arrow_cast([1, 2], 'LargeListView(Int64)'));
322+
----
323+
[2]
324+
305325

306326
include ./cleanup.slt.part

datafusion/sqllogictest/test_files/array/array_reverse.slt

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,15 @@ select array_contains(a, b) from array_has order by 1 nulls last;
164164
true
165165
NULL
166166

167-
# Expected output (once supported):
168-
# ----
169-
# [5, 4, 3, 2, 1]
170-
query error
171-
select array_reverse(arrow_cast(make_array(1, 2, 3, 4, 5), 'ListView(Int64)'));
167+
query ?
168+
select array_reverse(arrow_cast([1, 2, 3, 4, 5], 'ListView(Int64)'));
169+
----
170+
[5, 4, 3, 2, 1]
172171

172+
query ?
173+
select array_reverse(arrow_cast([1, 2, 3, 4, 5], 'LargeListView(Int64)'));
174+
----
175+
[5, 4, 3, 2, 1]
173176

174177
statement ok
175178
drop table test_create_array_table;

datafusion/sqllogictest/test_files/array/array_slice.slt

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,11 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0,
7979
----
8080
[1, 2, 3, 4, 5] [h, e, l, l, o]
8181

82-
# TODO make error message nicer: https://github.com/apache/datafusion/issues/19004
83-
# Expected output (once supported):
84-
# ----
85-
# [1, 2, 3, 4, 5] [h, e, l, l, o]
86-
query error Failed to coerce arguments to satisfy a call to 'array_slice' function:
87-
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'ListView(Int64)'), 0, 6),
88-
array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'ListView(Utf8)'), 0, 5);
82+
query ??
83+
select array_slice(arrow_cast([1, 2, 3, 4, 5], 'ListView(Int64)'), 0, 6),
84+
array_slice(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'ListView(Utf8)'), 0, 5);
85+
----
86+
[1, 2, 3, 4, 5] [h, e, l, l, o]
8987

9088
query ??
9189
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), 0, 6),
@@ -126,13 +124,11 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2,
126124
----
127125
[2, 3, 4, 5] [l, l, o]
128126

129-
# TODO: Enable once array_slice supports LargeListView types.
130-
# Expected output (once supported):
131-
# ----
132-
# [2, 3, 4, 5] [l, l, o]
133-
query error Failed to coerce arguments to satisfy a call to 'array_slice' function:
134-
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeListView(Int64)'), 2, 6),
135-
array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeListView(Utf8)'), 3, 7);
127+
query ??
128+
select array_slice(arrow_cast([1, 2, 3, 4, 5], 'LargeListView(Int64)'), 2, 6),
129+
array_slice(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeListView(Utf8)'), 3, 7);
130+
----
131+
[2, 3, 4, 5] [l, l, o]
136132

137133

138134
# array_slice scalar function #6 (with positive indexes; nested array)

datafusion/sqllogictest/test_files/array/array_sort.slt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,15 @@ select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3
254254
----
255255
[NULL, NULL, -5, 1, 3, 5] [NULL, 1, 2, 3] [NULL, 3, 2, 1]
256256

257+
query ?
258+
select array_sort(arrow_cast([1, 3, null, 5, NULL, -5], 'ListView(Int64)'));
259+
----
260+
[NULL, NULL, -5, 1, 3, 5]
261+
262+
query ?
263+
select array_sort(arrow_cast([1, 3, null, 5, NULL, -5], 'LargeListView(Int64)'));
264+
----
265+
[NULL, NULL, -5, 1, 3, 5]
257266

258267

259268
include ./cleanup.slt.part

0 commit comments

Comments
 (0)