Skip to content

Commit 5981d62

Browse files
authored
chore: remove datatype check functions in favour of upstream versions (#20104)
We now have [`DataType::is_string`](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_string) and [`DataType::is_decimal`](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_decimal) upstream in arrow, so use them in place of our local helper functions.
1 parent c3eb9ef commit 5981d62

5 files changed

Lines changed: 22 additions & 62 deletions

File tree

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -351,16 +351,6 @@ impl<'a> BinaryTypeCoercer<'a> {
351351

352352
// TODO Move the rest inside of BinaryTypeCoercer
353353

354-
fn is_decimal(data_type: &DataType) -> bool {
355-
matches!(
356-
data_type,
357-
DataType::Decimal32(..)
358-
| DataType::Decimal64(..)
359-
| DataType::Decimal128(..)
360-
| DataType::Decimal256(..)
361-
)
362-
}
363-
364354
/// Returns true if both operands are Date types (Date32 or Date64)
365355
/// Used to detect Date - Date operations which should return Int64 (days difference)
366356
fn is_date_minus_date(lhs: &DataType, rhs: &DataType) -> bool {
@@ -402,8 +392,8 @@ fn math_decimal_coercion(
402392
}
403393
// Cross-variant decimal coercion - choose larger variant with appropriate precision/scale
404394
(lhs, rhs)
405-
if is_decimal(lhs)
406-
&& is_decimal(rhs)
395+
if lhs.is_decimal()
396+
&& rhs.is_decimal()
407397
&& std::mem::discriminant(lhs) != std::mem::discriminant(rhs) =>
408398
{
409399
let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?;
@@ -1018,17 +1008,17 @@ pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Data
10181008
match (lhs_type, rhs_type) {
10191009
// Same decimal types
10201010
(lhs_type, rhs_type)
1021-
if is_decimal(lhs_type)
1022-
&& is_decimal(rhs_type)
1011+
if lhs_type.is_decimal()
1012+
&& rhs_type.is_decimal()
10231013
&& std::mem::discriminant(lhs_type)
10241014
== std::mem::discriminant(rhs_type) =>
10251015
{
10261016
get_wider_decimal_type(lhs_type, rhs_type)
10271017
}
10281018
// Mismatched decimal types
10291019
(lhs_type, rhs_type)
1030-
if is_decimal(lhs_type)
1031-
&& is_decimal(rhs_type)
1020+
if lhs_type.is_decimal()
1021+
&& rhs_type.is_decimal()
10321022
&& std::mem::discriminant(lhs_type)
10331023
!= std::mem::discriminant(rhs_type) =>
10341024
{

datafusion/expr/src/type_coercion/mod.rs

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,6 @@ pub fn is_signed_numeric(dt: &DataType) -> bool {
5858
)
5959
}
6060

61-
/// Determine whether the given data type `dt` is `Null`.
62-
pub fn is_null(dt: &DataType) -> bool {
63-
*dt == DataType::Null
64-
}
65-
6661
/// Determine whether the given data type `dt` is a `Timestamp`.
6762
pub fn is_timestamp(dt: &DataType) -> bool {
6863
matches!(dt, DataType::Timestamp(_, _))
@@ -80,22 +75,3 @@ pub fn is_datetime(dt: &DataType) -> bool {
8075
DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _)
8176
)
8277
}
83-
84-
/// Determine whether the given data type `dt` is a `Utf8` or `Utf8View` or `LargeUtf8`.
85-
pub fn is_utf8_or_utf8view_or_large_utf8(dt: &DataType) -> bool {
86-
matches!(
87-
dt,
88-
DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8
89-
)
90-
}
91-
92-
/// Determine whether the given data type `dt` is a `Decimal`.
93-
pub fn is_decimal(dt: &DataType) -> bool {
94-
matches!(
95-
dt,
96-
DataType::Decimal32(_, _)
97-
| DataType::Decimal64(_, _)
98-
| DataType::Decimal128(_, _)
99-
| DataType::Decimal256(_, _)
100-
)
101-
}

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ use datafusion_expr::expr_schema::cast_subquery;
4343
use datafusion_expr::logical_plan::Subquery;
4444
use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
4545
use datafusion_expr::type_coercion::functions::{UDFCoercionExt, fields_with_udf};
46+
use datafusion_expr::type_coercion::is_datetime;
4647
use datafusion_expr::type_coercion::other::{
4748
get_coerce_type_for_case_expression, get_coerce_type_for_list,
4849
};
49-
use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
5050
use datafusion_expr::utils::merge_schema;
5151
use datafusion_expr::{
5252
Cast, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, Union,
@@ -513,10 +513,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
513513
.data;
514514
let expr_type = expr.get_type(self.schema)?;
515515
let subquery_type = new_plan.schema().field(0).data_type();
516-
if (expr_type.is_numeric()
517-
&& is_utf8_or_utf8view_or_large_utf8(subquery_type))
518-
|| (subquery_type.is_numeric()
519-
&& is_utf8_or_utf8view_or_large_utf8(&expr_type))
516+
if (expr_type.is_numeric() && subquery_type.is_string())
517+
|| (subquery_type.is_numeric() && expr_type.is_string())
520518
{
521519
return plan_err!(
522520
"expr type {expr_type} can't cast to {subquery_type} in SetComparison"
@@ -890,12 +888,15 @@ fn coerce_frame_bound(
890888

891889
fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
892890
if col_type.is_numeric()
893-
|| is_utf8_or_utf8view_or_large_utf8(col_type)
894-
|| matches!(col_type, DataType::List(_))
895-
|| matches!(col_type, DataType::LargeList(_))
896-
|| matches!(col_type, DataType::FixedSizeList(_, _))
897-
|| matches!(col_type, DataType::Null)
898-
|| matches!(col_type, DataType::Boolean)
891+
|| col_type.is_string()
892+
|| col_type.is_null()
893+
|| matches!(
894+
col_type,
895+
DataType::List(_)
896+
| DataType::LargeList(_)
897+
| DataType::FixedSizeList(_, _)
898+
| DataType::Boolean
899+
)
899900
{
900901
Ok(col_type.clone())
901902
} else if is_datetime(col_type) {

datafusion/physical-expr/src/expressions/negative.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use datafusion_expr::statistics::Distribution::{
3737
};
3838
use datafusion_expr::{
3939
ColumnarValue,
40-
type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp},
40+
type_coercion::{is_interval, is_signed_numeric, is_timestamp},
4141
};
4242

4343
/// Negative expression
@@ -190,7 +190,7 @@ pub fn negative(
190190
input_schema: &Schema,
191191
) -> Result<Arc<dyn PhysicalExpr>> {
192192
let data_type = arg.data_type(input_schema)?;
193-
if is_null(&data_type) {
193+
if data_type.is_null() {
194194
Ok(arg)
195195
} else if !is_signed_numeric(&data_type)
196196
&& !is_interval(&data_type)

datafusion/pruning/src/pruning_predicate.rs

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,13 +1205,6 @@ fn is_compare_op(op: Operator) -> bool {
12051205
)
12061206
}
12071207

1208-
fn is_string_type(data_type: &DataType) -> bool {
1209-
matches!(
1210-
data_type,
1211-
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
1212-
)
1213-
}
1214-
12151208
// The pruning logic is based on the comparing the min/max bounds.
12161209
// Must make sure the two type has order.
12171210
// For example, casts from string to numbers is not correct.
@@ -1233,7 +1226,7 @@ fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Re
12331226
// If both types are strings or both are not strings (number, timestamp, etc)
12341227
// then we can compare them.
12351228
// PruningPredicate does not support casting of strings to numbers and such.
1236-
if is_string_type(from_type) == is_string_type(to_type) {
1229+
if from_type.is_string() == to_type.is_string() {
12371230
Ok(())
12381231
} else {
12391232
plan_err!(
@@ -4681,7 +4674,7 @@ mod tests {
46814674
true,
46824675
// s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep)
46834676
true,
4684-
// s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate
4677+
// s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate
46854678
// original (min, max) maybe ("A\u{10ffff}\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}\u{10ffff}\u{10ffff}")
46864679
true,
46874680
];

0 commit comments

Comments
 (0)