@@ -2000,6 +2000,209 @@ impl ScalarValue {
20002000 }
20012001 }
20022002
2003+ #[ inline]
2004+ fn can_use_direct_add ( lhs : & ScalarValue , rhs : & ScalarValue ) -> bool {
2005+ matches ! (
2006+ ( lhs, rhs) ,
2007+ ( ScalarValue :: Int8 ( _) , ScalarValue :: Int8 ( _) )
2008+ | ( ScalarValue :: Int16 ( _) , ScalarValue :: Int16 ( _) )
2009+ | ( ScalarValue :: Int32 ( _) , ScalarValue :: Int32 ( _) )
2010+ | ( ScalarValue :: Int64 ( _) , ScalarValue :: Int64 ( _) )
2011+ | ( ScalarValue :: UInt8 ( _) , ScalarValue :: UInt8 ( _) )
2012+ | ( ScalarValue :: UInt16 ( _) , ScalarValue :: UInt16 ( _) )
2013+ | ( ScalarValue :: UInt32 ( _) , ScalarValue :: UInt32 ( _) )
2014+ | ( ScalarValue :: UInt64 ( _) , ScalarValue :: UInt64 ( _) )
2015+ | ( ScalarValue :: Float16 ( _) , ScalarValue :: Float16 ( _) )
2016+ | ( ScalarValue :: Float32 ( _) , ScalarValue :: Float32 ( _) )
2017+ | ( ScalarValue :: Float64 ( _) , ScalarValue :: Float64 ( _) )
2018+ | (
2019+ ScalarValue :: Decimal32 ( _, _, _) ,
2020+ ScalarValue :: Decimal32 ( _, _, _)
2021+ )
2022+ | (
2023+ ScalarValue :: Decimal64 ( _, _, _) ,
2024+ ScalarValue :: Decimal64 ( _, _, _)
2025+ )
2026+ | (
2027+ ScalarValue :: Decimal128 ( _, _, _) ,
2028+ ScalarValue :: Decimal128 ( _, _, _) ,
2029+ )
2030+ | (
2031+ ScalarValue :: Decimal256 ( _, _, _) ,
2032+ ScalarValue :: Decimal256 ( _, _, _) ,
2033+ )
2034+ )
2035+ }
2036+
2037+ #[ inline]
2038+ fn add_optional < T : ArrowNativeTypeOp > (
2039+ lhs : & mut Option < T > ,
2040+ rhs : Option < T > ,
2041+ checked : bool ,
2042+ ) -> Result < ( ) > {
2043+ match rhs {
2044+ Some ( rhs) => {
2045+ if let Some ( lhs) = lhs. as_mut ( ) {
2046+ * lhs = if checked {
2047+ lhs. add_checked ( rhs) . map_err ( |e| arrow_datafusion_err ! ( e) ) ?
2048+ } else {
2049+ lhs. add_wrapping ( rhs)
2050+ } ;
2051+ }
2052+ }
2053+ None => * lhs = None ,
2054+ }
2055+ Ok ( ( ) )
2056+ }
2057+
2058+ #[ inline]
2059+ fn add_decimal_values < T : DecimalType > (
2060+ lhs_value : & mut Option < T :: Native > ,
2061+ lhs_precision : & mut u8 ,
2062+ lhs_scale : & mut i8 ,
2063+ rhs_value : Option < T :: Native > ,
2064+ rhs_precision : u8 ,
2065+ rhs_scale : i8 ,
2066+ ) -> Result < ( ) >
2067+ where
2068+ T :: Native : ArrowNativeTypeOp ,
2069+ {
2070+ Self :: validate_decimal_or_internal_err :: < T > ( * lhs_precision, * lhs_scale) ?;
2071+ Self :: validate_decimal_or_internal_err :: < T > ( rhs_precision, rhs_scale) ?;
2072+
2073+ let result_scale = ( * lhs_scale) . max ( rhs_scale) ;
2074+ // Decimal scales can be negative, so use a wider signed type for the
2075+ // intermediate precision arithmetic.
2076+ let lhs_precision_delta = i16:: from ( * lhs_precision) - i16:: from ( * lhs_scale) ;
2077+ let rhs_precision_delta = i16:: from ( rhs_precision) - i16:: from ( rhs_scale) ;
2078+ let result_precision =
2079+ ( i16:: from ( result_scale) + lhs_precision_delta. max ( rhs_precision_delta) + 1 )
2080+ . min ( i16:: from ( T :: MAX_PRECISION ) ) as u8 ;
2081+
2082+ Self :: validate_decimal_or_internal_err :: < T > ( result_precision, result_scale) ?;
2083+
2084+ let lhs_mul = T :: Native :: usize_as ( 10 )
2085+ . pow_checked ( ( result_scale - * lhs_scale) as u32 )
2086+ . map_err ( |e| arrow_datafusion_err ! ( e) ) ?;
2087+ let rhs_mul = T :: Native :: usize_as ( 10 )
2088+ . pow_checked ( ( result_scale - rhs_scale) as u32 )
2089+ . map_err ( |e| arrow_datafusion_err ! ( e) ) ?;
2090+
2091+ let result_value = match ( * lhs_value, rhs_value) {
2092+ ( Some ( lhs_value) , Some ( rhs_value) ) => Some (
2093+ lhs_value
2094+ . mul_checked ( lhs_mul)
2095+ . and_then ( |lhs| {
2096+ rhs_value
2097+ . mul_checked ( rhs_mul)
2098+ . and_then ( |rhs| lhs. add_checked ( rhs) )
2099+ } )
2100+ . map_err ( |e| arrow_datafusion_err ! ( e) ) ?,
2101+ ) ,
2102+ _ => None ,
2103+ } ;
2104+
2105+ * lhs_value = result_value;
2106+ * lhs_precision = result_precision;
2107+ * lhs_scale = result_scale;
2108+
2109+ Ok ( ( ) )
2110+ }
2111+
2112+ #[ inline]
2113+ fn try_add_in_place_impl (
2114+ & mut self ,
2115+ other : & ScalarValue ,
2116+ checked : bool ,
2117+ ) -> Result < bool > {
2118+ match ( self , other) {
2119+ ( ScalarValue :: Int8 ( lhs) , ScalarValue :: Int8 ( rhs) ) => {
2120+ Self :: add_optional ( lhs, * rhs, checked) ?;
2121+ }
2122+ ( ScalarValue :: Int16 ( lhs) , ScalarValue :: Int16 ( rhs) ) => {
2123+ Self :: add_optional ( lhs, * rhs, checked) ?;
2124+ }
2125+ ( ScalarValue :: Int32 ( lhs) , ScalarValue :: Int32 ( rhs) ) => {
2126+ Self :: add_optional ( lhs, * rhs, checked) ?;
2127+ }
2128+ ( ScalarValue :: Int64 ( lhs) , ScalarValue :: Int64 ( rhs) ) => {
2129+ Self :: add_optional ( lhs, * rhs, checked) ?;
2130+ }
2131+ ( ScalarValue :: UInt8 ( lhs) , ScalarValue :: UInt8 ( rhs) ) => {
2132+ Self :: add_optional ( lhs, * rhs, checked) ?;
2133+ }
2134+ ( ScalarValue :: UInt16 ( lhs) , ScalarValue :: UInt16 ( rhs) ) => {
2135+ Self :: add_optional ( lhs, * rhs, checked) ?;
2136+ }
2137+ ( ScalarValue :: UInt32 ( lhs) , ScalarValue :: UInt32 ( rhs) ) => {
2138+ Self :: add_optional ( lhs, * rhs, checked) ?;
2139+ }
2140+ ( ScalarValue :: UInt64 ( lhs) , ScalarValue :: UInt64 ( rhs) ) => {
2141+ Self :: add_optional ( lhs, * rhs, checked) ?;
2142+ }
2143+ ( ScalarValue :: Float16 ( lhs) , ScalarValue :: Float16 ( rhs) ) => {
2144+ Self :: add_optional ( lhs, * rhs, checked) ?;
2145+ }
2146+ ( ScalarValue :: Float32 ( lhs) , ScalarValue :: Float32 ( rhs) ) => {
2147+ Self :: add_optional ( lhs, * rhs, checked) ?;
2148+ }
2149+ ( ScalarValue :: Float64 ( lhs) , ScalarValue :: Float64 ( rhs) ) => {
2150+ Self :: add_optional ( lhs, * rhs, checked) ?;
2151+ }
2152+ (
2153+ ScalarValue :: Decimal32 ( lhs, p, s) ,
2154+ ScalarValue :: Decimal32 ( rhs, rhs_p, rhs_s) ,
2155+ ) => {
2156+ Self :: add_decimal_values :: < Decimal32Type > (
2157+ lhs, p, s, * rhs, * rhs_p, * rhs_s,
2158+ ) ?;
2159+ }
2160+ (
2161+ ScalarValue :: Decimal64 ( lhs, p, s) ,
2162+ ScalarValue :: Decimal64 ( rhs, rhs_p, rhs_s) ,
2163+ ) => {
2164+ Self :: add_decimal_values :: < Decimal64Type > (
2165+ lhs, p, s, * rhs, * rhs_p, * rhs_s,
2166+ ) ?;
2167+ }
2168+ (
2169+ ScalarValue :: Decimal128 ( lhs, p, s) ,
2170+ ScalarValue :: Decimal128 ( rhs, rhs_p, rhs_s) ,
2171+ ) => {
2172+ Self :: add_decimal_values :: < Decimal128Type > (
2173+ lhs, p, s, * rhs, * rhs_p, * rhs_s,
2174+ ) ?;
2175+ }
2176+ (
2177+ ScalarValue :: Decimal256 ( lhs, p, s) ,
2178+ ScalarValue :: Decimal256 ( rhs, rhs_p, rhs_s) ,
2179+ ) => {
2180+ Self :: add_decimal_values :: < Decimal256Type > (
2181+ lhs, p, s, * rhs, * rhs_p, * rhs_s,
2182+ ) ?;
2183+ }
2184+ _ => return Ok ( false ) ,
2185+ }
2186+
2187+ Ok ( true )
2188+ }
2189+
2190+ #[ inline]
2191+ pub ( crate ) fn try_add_wrapping_in_place (
2192+ & mut self ,
2193+ other : & ScalarValue ,
2194+ ) -> Result < bool > {
2195+ self . try_add_in_place_impl ( other, false )
2196+ }
2197+
2198+ #[ inline]
2199+ pub ( crate ) fn try_add_checked_in_place (
2200+ & mut self ,
2201+ other : & ScalarValue ,
2202+ ) -> Result < bool > {
2203+ self . try_add_in_place_impl ( other, true )
2204+ }
2205+
20032206 /// Calculate arithmetic negation for a scalar value
20042207 pub fn arithmetic_negate ( & self ) -> Result < Self > {
20052208 fn neg_checked_with_ctx < T : ArrowNativeTypeOp > (
@@ -2135,7 +2338,16 @@ impl ScalarValue {
21352338 /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code
21362339 /// should operate on Arrays directly, using vectorized array kernels
21372340 pub fn add < T : Borrow < ScalarValue > > ( & self , other : T ) -> Result < ScalarValue > {
2138- let r = add_wrapping ( & self . to_scalar ( ) ?, & other. borrow ( ) . to_scalar ( ) ?) ?;
2341+ let other = other. borrow ( ) ;
2342+ if Self :: can_use_direct_add ( self , other) {
2343+ let mut result = self . clone ( ) ;
2344+ if result. try_add_wrapping_in_place ( other) ? {
2345+ return Ok ( result) ;
2346+ }
2347+ debug_assert ! ( false , "fast-path eligibility drifted from implementation" ) ;
2348+ }
2349+
2350+ let r = add_wrapping ( & self . to_scalar ( ) ?, & other. to_scalar ( ) ?) ?;
21392351 Self :: try_from_array ( r. as_ref ( ) , 0 )
21402352 }
21412353
@@ -2144,7 +2356,16 @@ impl ScalarValue {
21442356 /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code
21452357 /// should operate on Arrays directly, using vectorized array kernels
21462358 pub fn add_checked < T : Borrow < ScalarValue > > ( & self , other : T ) -> Result < ScalarValue > {
2147- let r = add ( & self . to_scalar ( ) ?, & other. borrow ( ) . to_scalar ( ) ?) ?;
2359+ let other = other. borrow ( ) ;
2360+ if Self :: can_use_direct_add ( self , other) {
2361+ let mut result = self . clone ( ) ;
2362+ if result. try_add_checked_in_place ( other) ? {
2363+ return Ok ( result) ;
2364+ }
2365+ debug_assert ! ( false , "fast-path eligibility drifted from implementation" ) ;
2366+ }
2367+
2368+ let r = add ( & self . to_scalar ( ) ?, & other. to_scalar ( ) ?) ?;
21482369 Self :: try_from_array ( r. as_ref ( ) , 0 )
21492370 }
21502371
@@ -5943,6 +6164,68 @@ mod tests {
59436164 Ok ( ( ) )
59446165 }
59456166
6167+ #[ test]
6168+ fn scalar_add_trait_null_test ( ) -> Result < ( ) > {
6169+ let int_value = ScalarValue :: Int32 ( Some ( 42 ) ) ;
6170+
6171+ assert_eq ! (
6172+ int_value. add( ScalarValue :: Int32 ( None ) ) ?,
6173+ ScalarValue :: Int32 ( None )
6174+ ) ;
6175+
6176+ Ok ( ( ) )
6177+ }
6178+
6179+ #[ test]
6180+ fn scalar_add_trait_wrapping_overflow_test ( ) -> Result < ( ) > {
6181+ let int_value = ScalarValue :: Int32 ( Some ( i32:: MAX ) ) ;
6182+ let one = ScalarValue :: Int32 ( Some ( 1 ) ) ;
6183+
6184+ assert_eq ! ( int_value. add( one) ?, ScalarValue :: Int32 ( Some ( i32 :: MIN ) ) ) ;
6185+
6186+ Ok ( ( ) )
6187+ }
6188+
6189+ #[ test]
6190+ fn scalar_add_trait_decimal_scale_test ( ) -> Result < ( ) > {
6191+ let decimal = ScalarValue :: Decimal128 ( Some ( 123 ) , 10 , 2 ) ;
6192+ let decimal_2 = ScalarValue :: Decimal128 ( Some ( 4 ) , 9 , 1 ) ;
6193+
6194+ assert_eq ! (
6195+ decimal. add( decimal_2) ?,
6196+ ScalarValue :: Decimal128 ( Some ( 163 ) , 11 , 2 )
6197+ ) ;
6198+
6199+ Ok ( ( ) )
6200+ }
6201+
6202+ #[ test]
6203+ fn scalar_add_trait_decimal256_scale_test ( ) -> Result < ( ) > {
6204+ let decimal = ScalarValue :: Decimal256 ( Some ( i256:: from ( 123 ) ) , 10 , 2 ) ;
6205+ let decimal_2 = ScalarValue :: Decimal256 ( Some ( i256:: from ( 4 ) ) , 9 , 1 ) ;
6206+
6207+ assert_eq ! (
6208+ decimal. add( decimal_2) ?,
6209+ ScalarValue :: Decimal256 ( Some ( i256:: from( 163 ) ) , 11 , 2 )
6210+ ) ;
6211+
6212+ Ok ( ( ) )
6213+ }
6214+
6215+ #[ test]
6216+ fn scalar_add_trait_decimal_negative_scale_test ( ) -> Result < ( ) > {
6217+ let decimal = ScalarValue :: Decimal128 ( Some ( 1 ) , DECIMAL128_MAX_PRECISION , i8:: MIN ) ;
6218+ let decimal_2 =
6219+ ScalarValue :: Decimal128 ( Some ( 1 ) , DECIMAL128_MAX_PRECISION , i8:: MIN ) ;
6220+
6221+ assert_eq ! (
6222+ decimal. add( decimal_2) ?,
6223+ ScalarValue :: Decimal128 ( Some ( 2 ) , DECIMAL128_MAX_PRECISION , i8 :: MIN )
6224+ ) ;
6225+
6226+ Ok ( ( ) )
6227+ }
6228+
59466229 #[ test]
59476230 fn scalar_sub_trait_test ( ) -> Result < ( ) > {
59486231 let float_value = ScalarValue :: Float64 ( Some ( 123. ) ) ;
@@ -6042,6 +6325,43 @@ mod tests {
60426325 Ok ( ( ) )
60436326 }
60446327
6328+ #[ test]
6329+ fn scalar_decimal_add_overflow_test ( ) {
6330+ check_scalar_decimal_add_overflow :: < Decimal128Type > (
6331+ ScalarValue :: Decimal128 ( Some ( i128:: MAX ) , DECIMAL128_MAX_PRECISION , 0 ) ,
6332+ ScalarValue :: Decimal128 ( Some ( 1 ) , DECIMAL128_MAX_PRECISION , 0 ) ,
6333+ ) ;
6334+ check_scalar_decimal_add_overflow :: < Decimal256Type > (
6335+ ScalarValue :: Decimal256 ( Some ( i256:: MAX ) , DECIMAL256_MAX_PRECISION , 0 ) ,
6336+ ScalarValue :: Decimal256 ( Some ( i256:: ONE ) , DECIMAL256_MAX_PRECISION , 0 ) ,
6337+ ) ;
6338+ }
6339+
6340+ #[ test]
6341+ fn scalar_decimal_in_place_add_error_preserves_lhs ( ) {
6342+ let mut lhs =
6343+ ScalarValue :: Decimal128 ( Some ( i128:: MAX ) , DECIMAL128_MAX_PRECISION , 0 ) ;
6344+ let original = lhs. clone ( ) ;
6345+
6346+ let err = lhs
6347+ . try_add_checked_in_place ( & ScalarValue :: Decimal128 (
6348+ Some ( 1 ) ,
6349+ DECIMAL128_MAX_PRECISION ,
6350+ 0 ,
6351+ ) )
6352+ . unwrap_err ( )
6353+ . strip_backtrace ( ) ;
6354+
6355+ assert_eq ! (
6356+ err,
6357+ format!(
6358+ "Arrow error: Arithmetic overflow: Overflow happened on: {} + 1" ,
6359+ i128 :: MAX
6360+ )
6361+ ) ;
6362+ assert_eq ! ( lhs, original) ;
6363+ }
6364+
60456365 // Verifies that ScalarValue has the same behavior with compute kernel when it overflows.
60466366 fn check_scalar_add_overflow < T > ( left : ScalarValue , right : ScalarValue )
60476367 where
@@ -6058,6 +6378,22 @@ mod tests {
60586378 assert_eq ! ( scalar_result. is_ok( ) , arrow_result. is_ok( ) ) ;
60596379 }
60606380
6381+ // Verifies the decimal fast path preserves the same overflow behavior as Arrow kernels.
6382+ fn check_scalar_decimal_add_overflow < T > ( left : ScalarValue , right : ScalarValue )
6383+ where
6384+ T : ArrowPrimitiveType ,
6385+ {
6386+ let scalar_result = left. add ( & right) ;
6387+
6388+ let left_array = left. to_array ( ) . expect ( "Failed to convert to array" ) ;
6389+ let right_array = right. to_array ( ) . expect ( "Failed to convert to array" ) ;
6390+ let arrow_left_array = left_array. as_primitive :: < T > ( ) ;
6391+ let arrow_right_array = right_array. as_primitive :: < T > ( ) ;
6392+ let arrow_result = add_wrapping ( arrow_left_array, arrow_right_array) ;
6393+
6394+ assert_eq ! ( scalar_result. is_ok( ) , arrow_result. is_ok( ) ) ;
6395+ }
6396+
60616397 #[ test]
60626398 fn test_interval_add_timestamp ( ) -> Result < ( ) > {
60636399 let interval = ScalarValue :: IntervalMonthDayNano ( Some ( IntervalMonthDayNano {
0 commit comments