Skip to content

Commit d63326d

Browse files
authored
Merge branch 'main' into feat_migrate_ffi_to_stabby
2 parents e866aba + 3bdcdf5 commit d63326d

3 files changed

Lines changed: 437 additions & 120 deletions

File tree

datafusion/common/src/scalar/mod.rs

Lines changed: 338 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,6 +2000,209 @@ impl ScalarValue {
20002000
}
20012001
}
20022002

2003+
#[inline]
2004+
fn can_use_direct_add(lhs: &ScalarValue, rhs: &ScalarValue) -> bool {
2005+
matches!(
2006+
(lhs, rhs),
2007+
(ScalarValue::Int8(_), ScalarValue::Int8(_))
2008+
| (ScalarValue::Int16(_), ScalarValue::Int16(_))
2009+
| (ScalarValue::Int32(_), ScalarValue::Int32(_))
2010+
| (ScalarValue::Int64(_), ScalarValue::Int64(_))
2011+
| (ScalarValue::UInt8(_), ScalarValue::UInt8(_))
2012+
| (ScalarValue::UInt16(_), ScalarValue::UInt16(_))
2013+
| (ScalarValue::UInt32(_), ScalarValue::UInt32(_))
2014+
| (ScalarValue::UInt64(_), ScalarValue::UInt64(_))
2015+
| (ScalarValue::Float16(_), ScalarValue::Float16(_))
2016+
| (ScalarValue::Float32(_), ScalarValue::Float32(_))
2017+
| (ScalarValue::Float64(_), ScalarValue::Float64(_))
2018+
| (
2019+
ScalarValue::Decimal32(_, _, _),
2020+
ScalarValue::Decimal32(_, _, _)
2021+
)
2022+
| (
2023+
ScalarValue::Decimal64(_, _, _),
2024+
ScalarValue::Decimal64(_, _, _)
2025+
)
2026+
| (
2027+
ScalarValue::Decimal128(_, _, _),
2028+
ScalarValue::Decimal128(_, _, _),
2029+
)
2030+
| (
2031+
ScalarValue::Decimal256(_, _, _),
2032+
ScalarValue::Decimal256(_, _, _),
2033+
)
2034+
)
2035+
}
2036+
2037+
#[inline]
2038+
fn add_optional<T: ArrowNativeTypeOp>(
2039+
lhs: &mut Option<T>,
2040+
rhs: Option<T>,
2041+
checked: bool,
2042+
) -> Result<()> {
2043+
match rhs {
2044+
Some(rhs) => {
2045+
if let Some(lhs) = lhs.as_mut() {
2046+
*lhs = if checked {
2047+
lhs.add_checked(rhs).map_err(|e| arrow_datafusion_err!(e))?
2048+
} else {
2049+
lhs.add_wrapping(rhs)
2050+
};
2051+
}
2052+
}
2053+
None => *lhs = None,
2054+
}
2055+
Ok(())
2056+
}
2057+
2058+
#[inline]
2059+
fn add_decimal_values<T: DecimalType>(
2060+
lhs_value: &mut Option<T::Native>,
2061+
lhs_precision: &mut u8,
2062+
lhs_scale: &mut i8,
2063+
rhs_value: Option<T::Native>,
2064+
rhs_precision: u8,
2065+
rhs_scale: i8,
2066+
) -> Result<()>
2067+
where
2068+
T::Native: ArrowNativeTypeOp,
2069+
{
2070+
Self::validate_decimal_or_internal_err::<T>(*lhs_precision, *lhs_scale)?;
2071+
Self::validate_decimal_or_internal_err::<T>(rhs_precision, rhs_scale)?;
2072+
2073+
let result_scale = (*lhs_scale).max(rhs_scale);
2074+
// Decimal scales can be negative, so use a wider signed type for the
2075+
// intermediate precision arithmetic.
2076+
let lhs_precision_delta = i16::from(*lhs_precision) - i16::from(*lhs_scale);
2077+
let rhs_precision_delta = i16::from(rhs_precision) - i16::from(rhs_scale);
2078+
let result_precision =
2079+
(i16::from(result_scale) + lhs_precision_delta.max(rhs_precision_delta) + 1)
2080+
.min(i16::from(T::MAX_PRECISION)) as u8;
2081+
2082+
Self::validate_decimal_or_internal_err::<T>(result_precision, result_scale)?;
2083+
2084+
let lhs_mul = T::Native::usize_as(10)
2085+
.pow_checked((result_scale - *lhs_scale) as u32)
2086+
.map_err(|e| arrow_datafusion_err!(e))?;
2087+
let rhs_mul = T::Native::usize_as(10)
2088+
.pow_checked((result_scale - rhs_scale) as u32)
2089+
.map_err(|e| arrow_datafusion_err!(e))?;
2090+
2091+
let result_value = match (*lhs_value, rhs_value) {
2092+
(Some(lhs_value), Some(rhs_value)) => Some(
2093+
lhs_value
2094+
.mul_checked(lhs_mul)
2095+
.and_then(|lhs| {
2096+
rhs_value
2097+
.mul_checked(rhs_mul)
2098+
.and_then(|rhs| lhs.add_checked(rhs))
2099+
})
2100+
.map_err(|e| arrow_datafusion_err!(e))?,
2101+
),
2102+
_ => None,
2103+
};
2104+
2105+
*lhs_value = result_value;
2106+
*lhs_precision = result_precision;
2107+
*lhs_scale = result_scale;
2108+
2109+
Ok(())
2110+
}
2111+
2112+
#[inline]
2113+
fn try_add_in_place_impl(
2114+
&mut self,
2115+
other: &ScalarValue,
2116+
checked: bool,
2117+
) -> Result<bool> {
2118+
match (self, other) {
2119+
(ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => {
2120+
Self::add_optional(lhs, *rhs, checked)?;
2121+
}
2122+
(ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => {
2123+
Self::add_optional(lhs, *rhs, checked)?;
2124+
}
2125+
(ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => {
2126+
Self::add_optional(lhs, *rhs, checked)?;
2127+
}
2128+
(ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => {
2129+
Self::add_optional(lhs, *rhs, checked)?;
2130+
}
2131+
(ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => {
2132+
Self::add_optional(lhs, *rhs, checked)?;
2133+
}
2134+
(ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => {
2135+
Self::add_optional(lhs, *rhs, checked)?;
2136+
}
2137+
(ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => {
2138+
Self::add_optional(lhs, *rhs, checked)?;
2139+
}
2140+
(ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
2141+
Self::add_optional(lhs, *rhs, checked)?;
2142+
}
2143+
(ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => {
2144+
Self::add_optional(lhs, *rhs, checked)?;
2145+
}
2146+
(ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
2147+
Self::add_optional(lhs, *rhs, checked)?;
2148+
}
2149+
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
2150+
Self::add_optional(lhs, *rhs, checked)?;
2151+
}
2152+
(
2153+
ScalarValue::Decimal32(lhs, p, s),
2154+
ScalarValue::Decimal32(rhs, rhs_p, rhs_s),
2155+
) => {
2156+
Self::add_decimal_values::<Decimal32Type>(
2157+
lhs, p, s, *rhs, *rhs_p, *rhs_s,
2158+
)?;
2159+
}
2160+
(
2161+
ScalarValue::Decimal64(lhs, p, s),
2162+
ScalarValue::Decimal64(rhs, rhs_p, rhs_s),
2163+
) => {
2164+
Self::add_decimal_values::<Decimal64Type>(
2165+
lhs, p, s, *rhs, *rhs_p, *rhs_s,
2166+
)?;
2167+
}
2168+
(
2169+
ScalarValue::Decimal128(lhs, p, s),
2170+
ScalarValue::Decimal128(rhs, rhs_p, rhs_s),
2171+
) => {
2172+
Self::add_decimal_values::<Decimal128Type>(
2173+
lhs, p, s, *rhs, *rhs_p, *rhs_s,
2174+
)?;
2175+
}
2176+
(
2177+
ScalarValue::Decimal256(lhs, p, s),
2178+
ScalarValue::Decimal256(rhs, rhs_p, rhs_s),
2179+
) => {
2180+
Self::add_decimal_values::<Decimal256Type>(
2181+
lhs, p, s, *rhs, *rhs_p, *rhs_s,
2182+
)?;
2183+
}
2184+
_ => return Ok(false),
2185+
}
2186+
2187+
Ok(true)
2188+
}
2189+
2190+
#[inline]
2191+
pub(crate) fn try_add_wrapping_in_place(
2192+
&mut self,
2193+
other: &ScalarValue,
2194+
) -> Result<bool> {
2195+
self.try_add_in_place_impl(other, false)
2196+
}
2197+
2198+
#[inline]
2199+
pub(crate) fn try_add_checked_in_place(
2200+
&mut self,
2201+
other: &ScalarValue,
2202+
) -> Result<bool> {
2203+
self.try_add_in_place_impl(other, true)
2204+
}
2205+
20032206
/// Calculate arithmetic negation for a scalar value
20042207
pub fn arithmetic_negate(&self) -> Result<Self> {
20052208
fn neg_checked_with_ctx<T: ArrowNativeTypeOp>(
@@ -2135,7 +2338,16 @@ impl ScalarValue {
21352338
/// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code
21362339
/// should operate on Arrays directly, using vectorized array kernels
21372340
pub fn add<T: Borrow<ScalarValue>>(&self, other: T) -> Result<ScalarValue> {
2138-
// NOTE(review): the `-` gutter above marks this as the pre-change line that
// the diff removes; the `+` lines below are its replacement.
let r = add_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?;
2341+
// Borrow once so the fast path and the fallback both work on `&ScalarValue`.
let other = other.borrow();
2342+
// Fast path: for variant pairs accepted by `can_use_direct_add`, add into a
// clone of `self` using the wrapping in-place helper.
if Self::can_use_direct_add(self, other) {
2343+
let mut result = self.clone();
2344+
if result.try_add_wrapping_in_place(other)? {
2345+
return Ok(result);
2346+
}
2347+
// Reached only when the eligibility check accepted a pair that the in-place
// helper reported as unhandled (`Ok(false)`). Debug builds assert; otherwise
// execution falls through to the `to_scalar` path below.
debug_assert!(false, "fast-path eligibility drifted from implementation");
2348+
}
2349+
2350+
// Fallback: convert both operands with `to_scalar`, run `add_wrapping`, and
// read element 0 of the result back into a `ScalarValue`.
let r = add_wrapping(&self.to_scalar()?, &other.to_scalar()?)?;
21392351
Self::try_from_array(r.as_ref(), 0)
21402352
}
21412353

@@ -2144,7 +2356,16 @@ impl ScalarValue {
21442356
/// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code
21452357
/// should operate on Arrays directly, using vectorized array kernels
21462358
pub fn add_checked<T: Borrow<ScalarValue>>(&self, other: T) -> Result<ScalarValue> {
2147-
// NOTE(review): the `-` gutter above marks this as the pre-change line that
// the diff removes; the `+` lines below are its replacement.
let r = add(&self.to_scalar()?, &other.borrow().to_scalar()?)?;
2359+
// Borrow once so the fast path and the fallback both work on `&ScalarValue`.
let other = other.borrow();
2360+
// Fast path: for variant pairs accepted by `can_use_direct_add`, add into a
// clone of `self` using the checked in-place helper.
if Self::can_use_direct_add(self, other) {
2361+
let mut result = self.clone();
2362+
if result.try_add_checked_in_place(other)? {
2363+
return Ok(result);
2364+
}
2365+
// Reached only when the eligibility check accepted a pair that the in-place
// helper reported as unhandled (`Ok(false)`). Debug builds assert; otherwise
// execution falls through to the `to_scalar` path below.
debug_assert!(false, "fast-path eligibility drifted from implementation");
2366+
}
2367+
2368+
// Fallback: convert both operands with `to_scalar`, run `add`, and read
// element 0 of the result back into a `ScalarValue`.
let r = add(&self.to_scalar()?, &other.to_scalar()?)?;
21482369
Self::try_from_array(r.as_ref(), 0)
21492370
}
21502371

@@ -5943,6 +6164,68 @@ mod tests {
59436164
Ok(())
59446165
}
59456166

6167+
#[test]
fn scalar_add_trait_null_test() -> Result<()> {
    // Non-NULL + NULL of the same integer type must yield NULL.
    let sum = ScalarValue::Int32(Some(42)).add(ScalarValue::Int32(None))?;
    let expected = ScalarValue::Int32(None);

    assert_eq!(sum, expected);

    Ok(())
}
6178+
6179+
#[test]
fn scalar_add_trait_wrapping_overflow_test() -> Result<()> {
    // `add` wraps: i32::MAX + 1 rolls over to i32::MIN.
    let max = ScalarValue::Int32(Some(i32::MAX));
    let wrapped = max.add(ScalarValue::Int32(Some(1)))?;

    assert_eq!(wrapped, ScalarValue::Int32(Some(i32::MIN)));

    Ok(())
}
6188+
6189+
#[test]
fn scalar_add_trait_decimal_scale_test() -> Result<()> {
    // 1.23 (scale 2) + 0.4 (scale 1) = 1.63 at scale 2, precision widened to 11.
    let augend = ScalarValue::Decimal128(Some(123), 10, 2);
    let addend = ScalarValue::Decimal128(Some(4), 9, 1);
    let expected = ScalarValue::Decimal128(Some(163), 11, 2);

    let sum = augend.add(addend)?;
    assert_eq!(sum, expected);

    Ok(())
}
6201+
6202+
#[test]
fn scalar_add_trait_decimal256_scale_test() -> Result<()> {
    // Same mixed-scale scenario as the Decimal128 test, on 256-bit decimals.
    let augend = ScalarValue::Decimal256(Some(i256::from(123)), 10, 2);
    let addend = ScalarValue::Decimal256(Some(i256::from(4)), 9, 1);
    let expected = ScalarValue::Decimal256(Some(i256::from(163)), 11, 2);

    let sum = augend.add(addend)?;
    assert_eq!(sum, expected);

    Ok(())
}
6214+
6215+
#[test]
fn scalar_add_trait_decimal_negative_scale_test() -> Result<()> {
    // Both operands use the most negative scale at maximum precision; the
    // sum keeps that precision and scale.
    let augend = ScalarValue::Decimal128(Some(1), DECIMAL128_MAX_PRECISION, i8::MIN);
    let addend = ScalarValue::Decimal128(Some(1), DECIMAL128_MAX_PRECISION, i8::MIN);
    let expected =
        ScalarValue::Decimal128(Some(2), DECIMAL128_MAX_PRECISION, i8::MIN);

    let sum = augend.add(addend)?;
    assert_eq!(sum, expected);

    Ok(())
}
6228+
59466229
#[test]
59476230
fn scalar_sub_trait_test() -> Result<()> {
59486231
let float_value = ScalarValue::Float64(Some(123.));
@@ -6042,6 +6325,43 @@ mod tests {
60426325
Ok(())
60436326
}
60446327

6328+
#[test]
fn scalar_decimal_add_overflow_test() {
    // 128-bit: MAX + 1 at maximum precision, scale 0.
    let max_128 = ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0);
    let one_128 = ScalarValue::Decimal128(Some(1), DECIMAL128_MAX_PRECISION, 0);
    check_scalar_decimal_add_overflow::<Decimal128Type>(max_128, one_128);

    // 256-bit: MAX + 1 at maximum precision, scale 0.
    let max_256 = ScalarValue::Decimal256(Some(i256::MAX), DECIMAL256_MAX_PRECISION, 0);
    let one_256 = ScalarValue::Decimal256(Some(i256::ONE), DECIMAL256_MAX_PRECISION, 0);
    check_scalar_decimal_add_overflow::<Decimal256Type>(max_256, one_256);
}
6339+
6340+
#[test]
fn scalar_decimal_in_place_add_error_preserves_lhs() {
    // Start from i128::MAX so that adding one must overflow.
    let original =
        ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0);
    let mut lhs = original.clone();
    let one = ScalarValue::Decimal128(Some(1), DECIMAL128_MAX_PRECISION, 0);

    let err = lhs
        .try_add_checked_in_place(&one)
        .unwrap_err()
        .strip_backtrace();
    let expected = format!(
        "Arrow error: Arithmetic overflow: Overflow happened on: {} + 1",
        i128::MAX
    );

    assert_eq!(err, expected);
    // The failed add must not have touched the left-hand side.
    assert_eq!(lhs, original);
}
6364+
60456365
// Verifies that ScalarValue has the same behavior with compute kernel when it overflows.
60466366
fn check_scalar_add_overflow<T>(left: ScalarValue, right: ScalarValue)
60476367
where
@@ -6058,6 +6378,22 @@ mod tests {
60586378
assert_eq!(scalar_result.is_ok(), arrow_result.is_ok());
60596379
}
60606380

6381+
// Verifies the decimal fast path preserves the same overflow behavior as Arrow kernels.
6382+
fn check_scalar_decimal_add_overflow<T>(left: ScalarValue, right: ScalarValue)
6383+
where
6384+
T: ArrowPrimitiveType,
6385+
{
6386+
let scalar_result = left.add(&right);
6387+
6388+
let left_array = left.to_array().expect("Failed to convert to array");
6389+
let right_array = right.to_array().expect("Failed to convert to array");
6390+
let arrow_left_array = left_array.as_primitive::<T>();
6391+
let arrow_right_array = right_array.as_primitive::<T>();
6392+
let arrow_result = add_wrapping(arrow_left_array, arrow_right_array);
6393+
6394+
assert_eq!(scalar_result.is_ok(), arrow_result.is_ok());
6395+
}
6396+
60616397
#[test]
60626398
fn test_interval_add_timestamp() -> Result<()> {
60636399
let interval = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {

0 commit comments

Comments
 (0)