@@ -19,6 +19,7 @@ use std::any::Any;
1919use std:: str:: FromStr ;
2020use std:: sync:: Arc ;
2121
22+ use arrow:: array:: timezone:: Tz ;
2223use arrow:: array:: { Array , ArrayRef , Float64Array , Int32Array } ;
2324use arrow:: compute:: kernels:: cast_utils:: IntervalUnit ;
2425use arrow:: compute:: { DatePart , binary, date_part} ;
@@ -27,8 +28,10 @@ use arrow::datatypes::DataType::{
2728} ;
2829use arrow:: datatypes:: TimeUnit :: { Microsecond , Millisecond , Nanosecond , Second } ;
2930use arrow:: datatypes:: {
30- DataType , Field , FieldRef , IntervalUnit as ArrowIntervalUnit , TimeUnit ,
31+ DataType , Date32Type , Date64Type , Field , FieldRef , IntervalUnit as ArrowIntervalUnit ,
32+ TimeUnit ,
3133} ;
34+ use chrono:: { Datelike , NaiveDate , TimeZone , Utc } ;
3235use datafusion_common:: types:: { NativeType , logical_date} ;
3336
3437use datafusion_common:: {
@@ -44,9 +47,11 @@ use datafusion_common::{
4447 types:: logical_string,
4548 utils:: take_function_args,
4649} ;
50+ use datafusion_expr:: preimage:: PreimageResult ;
51+ use datafusion_expr:: simplify:: SimplifyContext ;
4752use datafusion_expr:: {
48- ColumnarValue , Documentation , ReturnFieldArgs , ScalarUDFImpl , Signature ,
49- TypeSignature , Volatility ,
53+ ColumnarValue , Documentation , Expr , ReturnFieldArgs , ScalarUDFImpl , Signature ,
54+ TypeSignature , Volatility , interval_arithmetic ,
5055} ;
5156use datafusion_expr_common:: signature:: { Coercion , TypeSignatureClass } ;
5257use datafusion_macros:: user_doc;
@@ -237,6 +242,71 @@ impl ScalarUDFImpl for DatePartFunc {
237242 } )
238243 }
239244
245+ // Only casting the year is supported since pruning other IntervalUnit is not possible
246+ // date_part(col, YEAR) = 2024 => col >= '2024-01-01' and col < '2025-01-01'
247+ // But for anything less than YEAR simplifying is not possible without specifying the bigger interval
248+ // date_part(col, MONTH) = 1 => col = '2023-01-01' or col = '2024-01-01' or ... or col = '3000-01-01'
249+ fn preimage (
250+ & self ,
251+ args : & [ Expr ] ,
252+ lit_expr : & Expr ,
253+ info : & SimplifyContext ,
254+ ) -> Result < PreimageResult > {
255+ let [ part, col_expr] = take_function_args ( self . name ( ) , args) ?;
256+
257+ // Get the interval unit from the part argument
258+ let interval_unit = part
259+ . as_literal ( )
260+ . and_then ( |sv| sv. try_as_str ( ) . flatten ( ) )
261+ . map ( part_normalization)
262+ . and_then ( |s| IntervalUnit :: from_str ( s) . ok ( ) ) ;
263+
264+ // only support extracting year
265+ match interval_unit {
266+ Some ( IntervalUnit :: Year ) => ( ) ,
267+ _ => return Ok ( PreimageResult :: None ) ,
268+ }
269+
270+ // Check if the argument is a literal (e.g. date_part(YEAR, col) = 2024)
271+ let Some ( argument_literal) = lit_expr. as_literal ( ) else {
272+ return Ok ( PreimageResult :: None ) ;
273+ } ;
274+
275+ // Extract i32 year from Scalar value
276+ let year = match argument_literal {
277+ ScalarValue :: Int32 ( Some ( y) ) => * y,
278+ _ => return Ok ( PreimageResult :: None ) ,
279+ } ;
280+
281+ // Can only extract year from Date32/64 and Timestamp column
282+ let target_type = match info. get_data_type ( col_expr) ? {
283+ Date32 | Date64 | Timestamp ( _, _) => & info. get_data_type ( col_expr) ?,
284+ _ => return Ok ( PreimageResult :: None ) ,
285+ } ;
286+
287+ // Compute the Interval bounds
288+ let Some ( start_time) = NaiveDate :: from_ymd_opt ( year, 1 , 1 ) else {
289+ return Ok ( PreimageResult :: None ) ;
290+ } ;
291+ let Some ( end_time) = start_time. with_year ( year + 1 ) else {
292+ return Ok ( PreimageResult :: None ) ;
293+ } ;
294+
295+ // Convert to ScalarValues
296+ let ( Some ( lower) , Some ( upper) ) = (
297+ date_to_scalar ( start_time, target_type) ,
298+ date_to_scalar ( end_time, target_type) ,
299+ ) else {
300+ return Ok ( PreimageResult :: None ) ;
301+ } ;
302+ let interval = Box :: new ( interval_arithmetic:: Interval :: try_new ( lower, upper) ?) ;
303+
304+ Ok ( PreimageResult :: Range {
305+ expr : col_expr. clone ( ) ,
306+ interval,
307+ } )
308+ }
309+
240310 fn aliases ( & self ) -> & [ String ] {
241311 & self . aliases
242312 }
@@ -251,6 +321,52 @@ fn is_epoch(part: &str) -> bool {
251321 matches ! ( part. to_lowercase( ) . as_str( ) , "epoch" )
252322}
253323
324+ fn date_to_scalar ( date : NaiveDate , target_type : & DataType ) -> Option < ScalarValue > {
325+ Some ( match target_type {
326+ Date32 => ScalarValue :: Date32 ( Some ( Date32Type :: from_naive_date ( date) ) ) ,
327+ Date64 => ScalarValue :: Date64 ( Some ( Date64Type :: from_naive_date ( date) ) ) ,
328+
329+ Timestamp ( unit, tz_opt) => {
330+ let naive_midnight = date. and_hms_opt ( 0 , 0 , 0 ) ?;
331+
332+ let utc_dt = if let Some ( tz_str) = tz_opt {
333+ let tz: Tz = tz_str. parse ( ) . ok ( ) ?;
334+
335+ let local = tz. from_local_datetime ( & naive_midnight) ;
336+
337+ let local_dt = match local {
338+ chrono:: offset:: LocalResult :: Single ( dt) => dt,
339+ chrono:: offset:: LocalResult :: Ambiguous ( dt1, _dt2) => dt1,
340+ chrono:: offset:: LocalResult :: None => local. earliest ( ) ?,
341+ } ;
342+
343+ local_dt. with_timezone ( & Utc )
344+ } else {
345+ Utc . from_utc_datetime ( & naive_midnight)
346+ } ;
347+
348+ match unit {
349+ Second => {
350+ ScalarValue :: TimestampSecond ( Some ( utc_dt. timestamp ( ) ) , tz_opt. clone ( ) )
351+ }
352+ Millisecond => ScalarValue :: TimestampMillisecond (
353+ Some ( utc_dt. timestamp_millis ( ) ) ,
354+ tz_opt. clone ( ) ,
355+ ) ,
356+ Microsecond => ScalarValue :: TimestampMicrosecond (
357+ Some ( utc_dt. timestamp_micros ( ) ) ,
358+ tz_opt. clone ( ) ,
359+ ) ,
360+ Nanosecond => ScalarValue :: TimestampNanosecond (
361+ Some ( utc_dt. timestamp_nanos_opt ( ) ?) ,
362+ tz_opt. clone ( ) ,
363+ ) ,
364+ }
365+ }
366+ _ => return None ,
367+ } )
368+ }
369+
254370// Try to remove quote if exist, if the quote is invalid, return original string and let the downstream function handle the error
255371fn part_normalization ( part : & str ) -> & str {
256372 part. strip_prefix ( |c| c == '\'' || c == '\"' )
0 commit comments