Skip to content

Commit 066504d

Browse files
authored
Merge branch 'main' into implement_groups_accumulator_count_distinct_primitive_types
2 parents 47110a8 + 5a427cb commit 066504d

14 files changed

Lines changed: 803 additions & 398 deletions

File tree

datafusion/expr-common/src/casts.rs

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,28 @@ pub fn is_supported_type(data_type: &DataType) -> bool {
5858
|| is_supported_binary_type(data_type)
5959
}
6060

61-
/// Returns true if unwrap_cast_in_comparison support this numeric type
61+
fn is_date_type(data_type: &DataType) -> bool {
62+
matches!(data_type, DataType::Date32 | DataType::Date64)
63+
}
64+
65+
/// Returns true when unwrapping a date/timestamp cast could change comparison
66+
/// semantics.
67+
///
68+
/// A `Date` stores only a calendar day, while a `Timestamp` stores a specific
69+
/// instant or wall-clock time. `Timestamp -> Date` is lossy because it drops the
70+
/// time-of-day. `Date -> Timestamp` is also lossy in this optimizer context
71+
/// because there is no unique inverse: converting a date to a timestamp has to
72+
/// invent a time component such as midnight.
73+
///
74+
/// For example, `CAST(ts AS DATE) = DATE '2024-01-01'` means "any timestamp
75+
/// during that day", but unwrapping it to `ts = TIMESTAMP '2024-01-01
76+
/// 00:00:00'` matches only midnight.
77+
fn is_lossy_temporal_cast(from_type: &DataType, to_type: &DataType) -> bool {
78+
(is_date_type(from_type) && to_type.is_temporal())
79+
|| (is_date_type(to_type) && from_type.is_temporal())
80+
}
81+
82+
/// Returns true if unwrap_cast_in_comparison supports this numeric type
6283
fn is_supported_numeric_type(data_type: &DataType) -> bool {
6384
matches!(
6485
data_type,
@@ -70,6 +91,8 @@ fn is_supported_numeric_type(data_type: &DataType) -> bool {
7091
| DataType::Int16
7192
| DataType::Int32
7293
| DataType::Int64
94+
| DataType::Date32
95+
| DataType::Date64
7396
| DataType::Decimal32(_, _)
7497
| DataType::Decimal64(_, _)
7598
| DataType::Decimal128(_, _)
@@ -107,6 +130,10 @@ fn try_cast_numeric_literal(
107130
return None;
108131
}
109132

133+
if is_lossy_temporal_cast(&lit_data_type, target_type) {
134+
return None;
135+
}
136+
110137
let mul = match target_type {
111138
DataType::UInt8
112139
| DataType::UInt16
@@ -115,7 +142,9 @@ fn try_cast_numeric_literal(
115142
| DataType::Int8
116143
| DataType::Int16
117144
| DataType::Int32
118-
| DataType::Int64 => 1_i128,
145+
| DataType::Int64
146+
| DataType::Date32
147+
| DataType::Date64 => 1_i128,
119148
DataType::Timestamp(_, _) => 1_i128,
120149
DataType::Decimal32(_, scale) => 10_i128.pow(*scale as u32),
121150
DataType::Decimal64(_, scale) => 10_i128.pow(*scale as u32),
@@ -129,8 +158,8 @@ fn try_cast_numeric_literal(
129158
DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128),
130159
DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
131160
DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
132-
DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
133-
DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
161+
DataType::Int32 | DataType::Date32 => (i32::MIN as i128, i32::MAX as i128),
162+
DataType::Int64 | DataType::Date64 => (i64::MIN as i128, i64::MAX as i128),
134163
DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
135164
DataType::Decimal32(precision, _) => (
136165
// Different precision for decimal32 can store different range of value.
@@ -164,6 +193,8 @@ fn try_cast_numeric_literal(
164193
ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
165194
ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
166195
ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
196+
ScalarValue::Date32(Some(v)) => (*v as i128).checked_mul(mul),
197+
ScalarValue::Date64(Some(v)) => (*v as i128).checked_mul(mul),
167198
ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
168199
ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
169200
ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
@@ -241,6 +272,8 @@ fn try_cast_numeric_literal(
241272
DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
242273
DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
243274
DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
275+
DataType::Date32 => ScalarValue::Date32(Some(value as i32)),
276+
DataType::Date64 => ScalarValue::Date64(Some(value as i64)),
244277
DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
245278
DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
246279
DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
@@ -700,6 +733,33 @@ mod tests {
700733
}
701734
}
702735

736+
#[test]
fn test_try_cast_to_type_date_timestamp_lossy_not_allowed() {
    // Each (literal, target type) pair crosses the date/timestamp boundary in
    // one direction or the other, so no casted literal may be produced.
    let lossy_cases = [
        (
            ScalarValue::Date32(Some(1)),
            DataType::Timestamp(TimeUnit::Second, None),
        ),
        (
            ScalarValue::Date64(Some(86_400_000)),
            DataType::Timestamp(TimeUnit::Millisecond, None),
        ),
        (ScalarValue::TimestampSecond(Some(86_400), None), DataType::Date32),
        (
            ScalarValue::TimestampMillisecond(Some(86_400_000), None),
            DataType::Date64,
        ),
    ];

    for (literal, target_type) in lossy_cases {
        expect_cast(literal, target_type, ExpectedCast::NoValue);
    }
}
762+
703763
#[test]
704764
fn test_try_cast_to_type_unsupported() {
705765
// int64 to list

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,16 @@ fn math_decimal_coercion(
378378
let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?;
379379
Some((lhs_type, value_type))
380380
}
381+
(RunEndEncoded(_, field), _) => {
382+
let (value_type, rhs_type) =
383+
math_decimal_coercion(field.data_type(), rhs_type)?;
384+
Some((value_type, rhs_type))
385+
}
386+
(_, RunEndEncoded(_, field)) => {
387+
let (lhs_type, value_type) =
388+
math_decimal_coercion(lhs_type, field.data_type())?;
389+
Some((lhs_type, value_type))
390+
}
381391
(
382392
Null,
383393
Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _),
@@ -1414,6 +1424,15 @@ fn mathematics_numerical_coercion(
14141424
(_, Dictionary(_, value_type)) => {
14151425
mathematics_numerical_coercion(lhs_type, value_type)
14161426
}
1427+
(RunEndEncoded(_, lhs_field), RunEndEncoded(_, rhs_field)) => {
1428+
mathematics_numerical_coercion(lhs_field.data_type(), rhs_field.data_type())
1429+
}
1430+
(RunEndEncoded(_, field), _) => {
1431+
mathematics_numerical_coercion(field.data_type(), rhs_type)
1432+
}
1433+
(_, RunEndEncoded(_, field)) => {
1434+
mathematics_numerical_coercion(lhs_type, field.data_type())
1435+
}
14171436
_ => numerical_coercion(lhs_type, rhs_type),
14181437
}
14191438
}
@@ -1493,6 +1512,15 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) ->
14931512
(_, Dictionary(_, value_type)) => {
14941513
lhs_type.is_numeric() && value_type.is_numeric()
14951514
}
1515+
(RunEndEncoded(_, lhs_field), RunEndEncoded(_, rhs_field)) => {
1516+
lhs_field.data_type().is_numeric() && rhs_field.data_type().is_numeric()
1517+
}
1518+
(RunEndEncoded(_, field), _) => {
1519+
field.data_type().is_numeric() && rhs_type.is_numeric()
1520+
}
1521+
(_, RunEndEncoded(_, field)) => {
1522+
lhs_type.is_numeric() && field.data_type().is_numeric()
1523+
}
14961524
_ => lhs_type.is_numeric() && rhs_type.is_numeric(),
14971525
}
14981526
}

datafusion/expr-common/src/type_coercion/binary/tests/run_end_encoded.rs

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,17 @@
1616
// under the License.
1717

1818
use super::*;
19+
use DataType::*;
20+
21+
fn ree(value_type: DataType) -> DataType {
22+
RunEndEncoded(
23+
Arc::new(Field::new("run_ends", Int32, false)),
24+
Arc::new(Field::new("values", value_type, false)),
25+
)
26+
}
1927

2028
#[test]
2129
fn test_ree_type_coercion() {
22-
use DataType::*;
23-
2430
let lhs_type = RunEndEncoded(
2531
Arc::new(Field::new("run_ends", Int8, false)),
2632
Arc::new(Field::new("values", Int32, false)),
@@ -97,3 +103,29 @@ fn test_ree_type_coercion() {
97103
Some(rhs_type.clone())
98104
);
99105
}
106+
107+
#[test]
fn test_ree_arithmetic_coercion() -> Result<()> {
    // Arithmetic with a run-end-encoded operand on the left, on the right,
    // and on both sides.
    test_coercion_binary_rule!(ree(Int64), Int64, Operator::Plus, Int64);
    test_coercion_binary_rule!(Int64, ree(Int64), Operator::Multiply, Int64);
    test_coercion_binary_rule!(ree(Int32), ree(Int64), Operator::Plus, Int64);

    // Decimal unwrapping through math_decimal_coercion
    let ree_decimal = ree(Decimal128(10, 2));

    // REE decimal on the left-hand side.
    let coerced = BinaryTypeCoercer::new(&ree_decimal, &Operator::Plus, &Int32)
        .get_input_types()?;
    assert_eq!(coerced, (Decimal128(10, 2), Decimal128(10, 0)));

    // REE decimal on the right-hand side.
    let coerced = BinaryTypeCoercer::new(&Int32, &Operator::Plus, &ree_decimal)
        .get_input_types()?;
    assert_eq!(coerced, (Decimal128(10, 0), Decimal128(10, 2)));

    // An REE of a non-numeric value type cannot take part in arithmetic.
    assert!(
        BinaryTypeCoercer::new(&ree(Utf8), &Operator::Plus, &Int32)
            .get_input_types()
            .is_err()
    );

    Ok(())
}

datafusion/functions/benches/left_right.rs

Lines changed: 54 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -16,40 +16,40 @@
1616
// under the License.
1717

1818
use std::hint::black_box;
19+
use std::ops::Range;
1920
use std::sync::Arc;
2021

2122
use arrow::array::{ArrayRef, Int64Array};
2223
use arrow::datatypes::{DataType, Field};
2324
use arrow::util::bench_util::{
2425
create_string_array_with_len, create_string_view_array_with_len,
2526
};
26-
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
27+
use criterion::{Criterion, criterion_group, criterion_main};
2728
use datafusion_common::config::ConfigOptions;
2829
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
2930
use datafusion_functions::unicode::{left, right};
3031

32+
const BATCH_SIZE: usize = 8192;
33+
3134
fn create_args(
32-
size: usize,
3335
str_len: usize,
34-
use_negative: bool,
36+
n_range: Range<i64>,
3537
is_string_view: bool,
3638
) -> Vec<ColumnarValue> {
3739
let string_arg = if is_string_view {
3840
ColumnarValue::Array(Arc::new(create_string_view_array_with_len(
39-
size, 0.1, str_len, true,
41+
BATCH_SIZE, 0.1, str_len, true,
4042
)))
4143
} else {
4244
ColumnarValue::Array(Arc::new(create_string_array_with_len::<i32>(
43-
size, 0.1, str_len,
45+
BATCH_SIZE, 0.1, str_len,
4446
)))
4547
};
4648

47-
// For negative n, we want to trigger the double-iteration code path
48-
let n_values: Vec<i64> = if use_negative {
49-
(0..size).map(|i| -((i % 10 + 1) as i64)).collect()
50-
} else {
51-
(0..size).map(|i| (i % 10 + 1) as i64).collect()
52-
};
49+
let n_span = (n_range.end - n_range.start) as usize;
50+
let n_values: Vec<i64> = (0..BATCH_SIZE)
51+
.map(|i| n_range.start + (i % n_span) as i64)
52+
.collect();
5353
let n_array = Arc::new(Int64Array::from(n_values));
5454

5555
vec![
@@ -59,68 +59,55 @@ fn create_args(
5959
}
6060

6161
fn criterion_benchmark(c: &mut Criterion) {
62-
let left_function = left();
63-
let right_function = right();
62+
// Short results (1-10 chars) produce inline StringView entries (≤12 bytes).
63+
// Long results (20-29 chars) produce out-of-line entries.
64+
let cases = [
65+
("short_result", 32, 1..11_i64),
66+
("long_result", 32, 20..30_i64),
67+
];
6468

65-
for function in [left_function, right_function] {
66-
for is_string_view in [false, true] {
67-
for is_negative in [false, true] {
68-
for size in [1024, 4096] {
69-
let function_name = function.name();
70-
let mut group =
71-
c.benchmark_group(format!("{function_name} size={size}"));
69+
for function in [left(), right()] {
70+
let mut group = c.benchmark_group(function.name().to_string());
7271

73-
let bench_name = format!(
74-
"{} {} n",
75-
if is_string_view {
76-
"string_view_array"
77-
} else {
78-
"string_array"
79-
},
80-
if is_negative { "negative" } else { "positive" },
81-
);
82-
let return_type = if is_string_view {
83-
DataType::Utf8View
84-
} else {
85-
DataType::Utf8
86-
};
87-
88-
let args = create_args(size, 32, is_negative, is_string_view);
89-
group.bench_function(BenchmarkId::new(bench_name, size), |b| {
90-
let arg_fields = args
91-
.iter()
92-
.enumerate()
93-
.map(|(idx, arg)| {
94-
Field::new(format!("arg_{idx}"), arg.data_type(), true)
95-
.into()
96-
})
97-
.collect::<Vec<_>>();
98-
let config_options = Arc::new(ConfigOptions::default());
72+
for is_string_view in [false, true] {
73+
let array_type = if is_string_view {
74+
"string_view"
75+
} else {
76+
"string"
77+
};
9978

100-
b.iter(|| {
101-
black_box(
102-
function
103-
.invoke_with_args(ScalarFunctionArgs {
104-
args: args.clone(),
105-
arg_fields: arg_fields.clone(),
106-
number_rows: size,
107-
return_field: Field::new(
108-
"f",
109-
return_type.clone(),
110-
true,
111-
)
112-
.into(),
113-
config_options: Arc::clone(&config_options),
114-
})
115-
.expect("should work"),
116-
)
117-
})
118-
});
79+
for (case_name, str_len, n_range) in &cases {
80+
let bench_name = format!("{array_type} {case_name}");
81+
let args = create_args(*str_len, n_range.clone(), is_string_view);
82+
let arg_fields: Vec<_> = args
83+
.iter()
84+
.enumerate()
85+
.map(|(idx, arg)| {
86+
Field::new(format!("arg_{idx}"), arg.data_type(), true).into()
87+
})
88+
.collect();
89+
let config_options = Arc::new(ConfigOptions::default());
90+
let return_field = Field::new("f", DataType::Utf8View, true).into();
11991

120-
group.finish();
121-
}
92+
group.bench_function(&bench_name, |b| {
93+
b.iter(|| {
94+
black_box(
95+
function
96+
.invoke_with_args(ScalarFunctionArgs {
97+
args: args.clone(),
98+
arg_fields: arg_fields.clone(),
99+
number_rows: BATCH_SIZE,
100+
return_field: Arc::clone(&return_field),
101+
config_options: Arc::clone(&config_options),
102+
})
103+
.expect("should work"),
104+
)
105+
})
106+
});
122107
}
123108
}
109+
110+
group.finish();
124111
}
125112
}
126113

0 commit comments

Comments
 (0)