Skip to content

Commit 35e78ca

Browse files
sdf-jkl and alamb authored
Optimize the evaluation of date_part(<col>) == <constant> when pushed down (#19733)
## Which issue does this PR close?

- Closes #19889.

## Rationale for this change

See the linked issue.

## What changes are included in this PR?

Added a `preimage` implementation for the `date_part` UDF. Added sqllogictests for the implementation.

## Are these changes tested?

Yes, with sqllogictests.

## Are there any user-facing changes?

No.

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 81f7a87 commit 35e78ca

2 files changed

Lines changed: 657 additions & 4 deletions

File tree

datafusion/functions/src/datetime/date_part.rs

Lines changed: 119 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use std::any::Any;
1919
use std::str::FromStr;
2020
use std::sync::Arc;
2121

22+
use arrow::array::timezone::Tz;
2223
use arrow::array::{Array, ArrayRef, Float64Array, Int32Array};
2324
use arrow::compute::kernels::cast_utils::IntervalUnit;
2425
use arrow::compute::{DatePart, binary, date_part};
@@ -27,8 +28,10 @@ use arrow::datatypes::DataType::{
2728
};
2829
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
2930
use arrow::datatypes::{
30-
DataType, Field, FieldRef, IntervalUnit as ArrowIntervalUnit, TimeUnit,
31+
DataType, Date32Type, Date64Type, Field, FieldRef, IntervalUnit as ArrowIntervalUnit,
32+
TimeUnit,
3133
};
34+
use chrono::{Datelike, NaiveDate, TimeZone, Utc};
3235
use datafusion_common::types::{NativeType, logical_date};
3336

3437
use datafusion_common::{
@@ -44,9 +47,11 @@ use datafusion_common::{
4447
types::logical_string,
4548
utils::take_function_args,
4649
};
50+
use datafusion_expr::preimage::PreimageResult;
51+
use datafusion_expr::simplify::SimplifyContext;
4752
use datafusion_expr::{
48-
ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature,
49-
TypeSignature, Volatility,
53+
ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDFImpl, Signature,
54+
TypeSignature, Volatility, interval_arithmetic,
5055
};
5156
use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
5257
use datafusion_macros::user_doc;
@@ -237,6 +242,71 @@ impl ScalarUDFImpl for DatePartFunc {
237242
})
238243
}
239244

245+
// Only casting the year is supported since pruning other IntervalUnit is not possible
246+
// date_part(col, YEAR) = 2024 => col >= '2024-01-01' and col < '2025-01-01'
247+
// But for anything less than YEAR simplifying is not possible without specifying the bigger interval
248+
// date_part(col, MONTH) = 1 => col = '2023-01-01' or col = '2024-01-01' or ... or col = '3000-01-01'
249+
fn preimage(
250+
&self,
251+
args: &[Expr],
252+
lit_expr: &Expr,
253+
info: &SimplifyContext,
254+
) -> Result<PreimageResult> {
255+
let [part, col_expr] = take_function_args(self.name(), args)?;
256+
257+
// Get the interval unit from the part argument
258+
let interval_unit = part
259+
.as_literal()
260+
.and_then(|sv| sv.try_as_str().flatten())
261+
.map(part_normalization)
262+
.and_then(|s| IntervalUnit::from_str(s).ok());
263+
264+
// only support extracting year
265+
match interval_unit {
266+
Some(IntervalUnit::Year) => (),
267+
_ => return Ok(PreimageResult::None),
268+
}
269+
270+
// Check if the argument is a literal (e.g. date_part(YEAR, col) = 2024)
271+
let Some(argument_literal) = lit_expr.as_literal() else {
272+
return Ok(PreimageResult::None);
273+
};
274+
275+
// Extract i32 year from Scalar value
276+
let year = match argument_literal {
277+
ScalarValue::Int32(Some(y)) => *y,
278+
_ => return Ok(PreimageResult::None),
279+
};
280+
281+
// Can only extract year from Date32/64 and Timestamp column
282+
let target_type = match info.get_data_type(col_expr)? {
283+
Date32 | Date64 | Timestamp(_, _) => &info.get_data_type(col_expr)?,
284+
_ => return Ok(PreimageResult::None),
285+
};
286+
287+
// Compute the Interval bounds
288+
let Some(start_time) = NaiveDate::from_ymd_opt(year, 1, 1) else {
289+
return Ok(PreimageResult::None);
290+
};
291+
let Some(end_time) = start_time.with_year(year + 1) else {
292+
return Ok(PreimageResult::None);
293+
};
294+
295+
// Convert to ScalarValues
296+
let (Some(lower), Some(upper)) = (
297+
date_to_scalar(start_time, target_type),
298+
date_to_scalar(end_time, target_type),
299+
) else {
300+
return Ok(PreimageResult::None);
301+
};
302+
let interval = Box::new(interval_arithmetic::Interval::try_new(lower, upper)?);
303+
304+
Ok(PreimageResult::Range {
305+
expr: col_expr.clone(),
306+
interval,
307+
})
308+
}
309+
240310
// Returns the alias list stored in this function's `aliases` field.
fn aliases(&self) -> &[String] {
241311
&self.aliases
242312
}
@@ -251,6 +321,52 @@ fn is_epoch(part: &str) -> bool {
251321
matches!(part.to_lowercase().as_str(), "epoch")
252322
}
253323

324+
fn date_to_scalar(date: NaiveDate, target_type: &DataType) -> Option<ScalarValue> {
325+
Some(match target_type {
326+
Date32 => ScalarValue::Date32(Some(Date32Type::from_naive_date(date))),
327+
Date64 => ScalarValue::Date64(Some(Date64Type::from_naive_date(date))),
328+
329+
Timestamp(unit, tz_opt) => {
330+
let naive_midnight = date.and_hms_opt(0, 0, 0)?;
331+
332+
let utc_dt = if let Some(tz_str) = tz_opt {
333+
let tz: Tz = tz_str.parse().ok()?;
334+
335+
let local = tz.from_local_datetime(&naive_midnight);
336+
337+
let local_dt = match local {
338+
chrono::offset::LocalResult::Single(dt) => dt,
339+
chrono::offset::LocalResult::Ambiguous(dt1, _dt2) => dt1,
340+
chrono::offset::LocalResult::None => local.earliest()?,
341+
};
342+
343+
local_dt.with_timezone(&Utc)
344+
} else {
345+
Utc.from_utc_datetime(&naive_midnight)
346+
};
347+
348+
match unit {
349+
Second => {
350+
ScalarValue::TimestampSecond(Some(utc_dt.timestamp()), tz_opt.clone())
351+
}
352+
Millisecond => ScalarValue::TimestampMillisecond(
353+
Some(utc_dt.timestamp_millis()),
354+
tz_opt.clone(),
355+
),
356+
Microsecond => ScalarValue::TimestampMicrosecond(
357+
Some(utc_dt.timestamp_micros()),
358+
tz_opt.clone(),
359+
),
360+
Nanosecond => ScalarValue::TimestampNanosecond(
361+
Some(utc_dt.timestamp_nanos_opt()?),
362+
tz_opt.clone(),
363+
),
364+
}
365+
}
366+
_ => return None,
367+
})
368+
}
369+
254370
// Try to remove quote if exist, if the quote is invalid, return original string and let the downstream function handle the error
255371
fn part_normalization(part: &str) -> &str {
256372
part.strip_prefix(|c| c == '\'' || c == '\"')

0 commit comments

Comments
 (0)