Skip to content

Commit 3bdcdf5

Browse files
perf: add in-place fast path for ScalarValue::add (#20959)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #20933 ## Rationale for this change Issue #20933 called out that `ScalarValue::add` was doing unnecessary work in hot paths such as statistics merging. The original suggestion was to mutate the lhs accumulator rather than alwaysbuilding a new scalar. This patch follows that direction and keeps the optimization in `ScalarValue` itself, which is a better long-term fit than duplicating specialized addition logic only in the stats code. <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? This change adds an in-place fast path for `ScalarValue::add` / `add_checked` for same-type numeric and decimal values. The patch also updates statistics merging to reuse the existing `sum_value` accumulator instead of creating a new `ScalarValue` for each addition. <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? Yes <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> --------- Co-authored-by: Martin Grigorov <martin-g@users.noreply.github.com>
1 parent 415bd42 commit 3bdcdf5

3 files changed

Lines changed: 437 additions & 120 deletions

File tree

datafusion/common/src/scalar/mod.rs

Lines changed: 338 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,6 +2000,209 @@ impl ScalarValue {
20002000
}
20012001
}
20022002

2003+
#[inline]
2004+
fn can_use_direct_add(lhs: &ScalarValue, rhs: &ScalarValue) -> bool {
2005+
matches!(
2006+
(lhs, rhs),
2007+
(ScalarValue::Int8(_), ScalarValue::Int8(_))
2008+
| (ScalarValue::Int16(_), ScalarValue::Int16(_))
2009+
| (ScalarValue::Int32(_), ScalarValue::Int32(_))
2010+
| (ScalarValue::Int64(_), ScalarValue::Int64(_))
2011+
| (ScalarValue::UInt8(_), ScalarValue::UInt8(_))
2012+
| (ScalarValue::UInt16(_), ScalarValue::UInt16(_))
2013+
| (ScalarValue::UInt32(_), ScalarValue::UInt32(_))
2014+
| (ScalarValue::UInt64(_), ScalarValue::UInt64(_))
2015+
| (ScalarValue::Float16(_), ScalarValue::Float16(_))
2016+
| (ScalarValue::Float32(_), ScalarValue::Float32(_))
2017+
| (ScalarValue::Float64(_), ScalarValue::Float64(_))
2018+
| (
2019+
ScalarValue::Decimal32(_, _, _),
2020+
ScalarValue::Decimal32(_, _, _)
2021+
)
2022+
| (
2023+
ScalarValue::Decimal64(_, _, _),
2024+
ScalarValue::Decimal64(_, _, _)
2025+
)
2026+
| (
2027+
ScalarValue::Decimal128(_, _, _),
2028+
ScalarValue::Decimal128(_, _, _),
2029+
)
2030+
| (
2031+
ScalarValue::Decimal256(_, _, _),
2032+
ScalarValue::Decimal256(_, _, _),
2033+
)
2034+
)
2035+
}
2036+
2037+
#[inline]
2038+
fn add_optional<T: ArrowNativeTypeOp>(
2039+
lhs: &mut Option<T>,
2040+
rhs: Option<T>,
2041+
checked: bool,
2042+
) -> Result<()> {
2043+
match rhs {
2044+
Some(rhs) => {
2045+
if let Some(lhs) = lhs.as_mut() {
2046+
*lhs = if checked {
2047+
lhs.add_checked(rhs).map_err(|e| arrow_datafusion_err!(e))?
2048+
} else {
2049+
lhs.add_wrapping(rhs)
2050+
};
2051+
}
2052+
}
2053+
None => *lhs = None,
2054+
}
2055+
Ok(())
2056+
}
2057+
2058+
#[inline]
2059+
fn add_decimal_values<T: DecimalType>(
2060+
lhs_value: &mut Option<T::Native>,
2061+
lhs_precision: &mut u8,
2062+
lhs_scale: &mut i8,
2063+
rhs_value: Option<T::Native>,
2064+
rhs_precision: u8,
2065+
rhs_scale: i8,
2066+
) -> Result<()>
2067+
where
2068+
T::Native: ArrowNativeTypeOp,
2069+
{
2070+
Self::validate_decimal_or_internal_err::<T>(*lhs_precision, *lhs_scale)?;
2071+
Self::validate_decimal_or_internal_err::<T>(rhs_precision, rhs_scale)?;
2072+
2073+
let result_scale = (*lhs_scale).max(rhs_scale);
2074+
// Decimal scales can be negative, so use a wider signed type for the
2075+
// intermediate precision arithmetic.
2076+
let lhs_precision_delta = i16::from(*lhs_precision) - i16::from(*lhs_scale);
2077+
let rhs_precision_delta = i16::from(rhs_precision) - i16::from(rhs_scale);
2078+
let result_precision =
2079+
(i16::from(result_scale) + lhs_precision_delta.max(rhs_precision_delta) + 1)
2080+
.min(i16::from(T::MAX_PRECISION)) as u8;
2081+
2082+
Self::validate_decimal_or_internal_err::<T>(result_precision, result_scale)?;
2083+
2084+
let lhs_mul = T::Native::usize_as(10)
2085+
.pow_checked((result_scale - *lhs_scale) as u32)
2086+
.map_err(|e| arrow_datafusion_err!(e))?;
2087+
let rhs_mul = T::Native::usize_as(10)
2088+
.pow_checked((result_scale - rhs_scale) as u32)
2089+
.map_err(|e| arrow_datafusion_err!(e))?;
2090+
2091+
let result_value = match (*lhs_value, rhs_value) {
2092+
(Some(lhs_value), Some(rhs_value)) => Some(
2093+
lhs_value
2094+
.mul_checked(lhs_mul)
2095+
.and_then(|lhs| {
2096+
rhs_value
2097+
.mul_checked(rhs_mul)
2098+
.and_then(|rhs| lhs.add_checked(rhs))
2099+
})
2100+
.map_err(|e| arrow_datafusion_err!(e))?,
2101+
),
2102+
_ => None,
2103+
};
2104+
2105+
*lhs_value = result_value;
2106+
*lhs_precision = result_precision;
2107+
*lhs_scale = result_scale;
2108+
2109+
Ok(())
2110+
}
2111+
2112+
#[inline]
2113+
fn try_add_in_place_impl(
2114+
&mut self,
2115+
other: &ScalarValue,
2116+
checked: bool,
2117+
) -> Result<bool> {
2118+
match (self, other) {
2119+
(ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => {
2120+
Self::add_optional(lhs, *rhs, checked)?;
2121+
}
2122+
(ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => {
2123+
Self::add_optional(lhs, *rhs, checked)?;
2124+
}
2125+
(ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => {
2126+
Self::add_optional(lhs, *rhs, checked)?;
2127+
}
2128+
(ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => {
2129+
Self::add_optional(lhs, *rhs, checked)?;
2130+
}
2131+
(ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => {
2132+
Self::add_optional(lhs, *rhs, checked)?;
2133+
}
2134+
(ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => {
2135+
Self::add_optional(lhs, *rhs, checked)?;
2136+
}
2137+
(ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => {
2138+
Self::add_optional(lhs, *rhs, checked)?;
2139+
}
2140+
(ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
2141+
Self::add_optional(lhs, *rhs, checked)?;
2142+
}
2143+
(ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => {
2144+
Self::add_optional(lhs, *rhs, checked)?;
2145+
}
2146+
(ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
2147+
Self::add_optional(lhs, *rhs, checked)?;
2148+
}
2149+
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
2150+
Self::add_optional(lhs, *rhs, checked)?;
2151+
}
2152+
(
2153+
ScalarValue::Decimal32(lhs, p, s),
2154+
ScalarValue::Decimal32(rhs, rhs_p, rhs_s),
2155+
) => {
2156+
Self::add_decimal_values::<Decimal32Type>(
2157+
lhs, p, s, *rhs, *rhs_p, *rhs_s,
2158+
)?;
2159+
}
2160+
(
2161+
ScalarValue::Decimal64(lhs, p, s),
2162+
ScalarValue::Decimal64(rhs, rhs_p, rhs_s),
2163+
) => {
2164+
Self::add_decimal_values::<Decimal64Type>(
2165+
lhs, p, s, *rhs, *rhs_p, *rhs_s,
2166+
)?;
2167+
}
2168+
(
2169+
ScalarValue::Decimal128(lhs, p, s),
2170+
ScalarValue::Decimal128(rhs, rhs_p, rhs_s),
2171+
) => {
2172+
Self::add_decimal_values::<Decimal128Type>(
2173+
lhs, p, s, *rhs, *rhs_p, *rhs_s,
2174+
)?;
2175+
}
2176+
(
2177+
ScalarValue::Decimal256(lhs, p, s),
2178+
ScalarValue::Decimal256(rhs, rhs_p, rhs_s),
2179+
) => {
2180+
Self::add_decimal_values::<Decimal256Type>(
2181+
lhs, p, s, *rhs, *rhs_p, *rhs_s,
2182+
)?;
2183+
}
2184+
_ => return Ok(false),
2185+
}
2186+
2187+
Ok(true)
2188+
}
2189+
2190+
#[inline]
2191+
pub(crate) fn try_add_wrapping_in_place(
2192+
&mut self,
2193+
other: &ScalarValue,
2194+
) -> Result<bool> {
2195+
self.try_add_in_place_impl(other, false)
2196+
}
2197+
2198+
#[inline]
2199+
pub(crate) fn try_add_checked_in_place(
2200+
&mut self,
2201+
other: &ScalarValue,
2202+
) -> Result<bool> {
2203+
self.try_add_in_place_impl(other, true)
2204+
}
2205+
20032206
/// Calculate arithmetic negation for a scalar value
20042207
pub fn arithmetic_negate(&self) -> Result<Self> {
20052208
fn neg_checked_with_ctx<T: ArrowNativeTypeOp>(
@@ -2135,7 +2338,16 @@ impl ScalarValue {
21352338
/// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code
21362339
/// should operate on Arrays directly, using vectorized array kernels
21372340
pub fn add<T: Borrow<ScalarValue>>(&self, other: T) -> Result<ScalarValue> {
2138-
let r = add_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?;
2341+
let other = other.borrow();
2342+
if Self::can_use_direct_add(self, other) {
2343+
let mut result = self.clone();
2344+
if result.try_add_wrapping_in_place(other)? {
2345+
return Ok(result);
2346+
}
2347+
debug_assert!(false, "fast-path eligibility drifted from implementation");
2348+
}
2349+
2350+
let r = add_wrapping(&self.to_scalar()?, &other.to_scalar()?)?;
21392351
Self::try_from_array(r.as_ref(), 0)
21402352
}
21412353

@@ -2144,7 +2356,16 @@ impl ScalarValue {
21442356
/// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code
21452357
/// should operate on Arrays directly, using vectorized array kernels
21462358
pub fn add_checked<T: Borrow<ScalarValue>>(&self, other: T) -> Result<ScalarValue> {
2147-
let r = add(&self.to_scalar()?, &other.borrow().to_scalar()?)?;
2359+
let other = other.borrow();
2360+
if Self::can_use_direct_add(self, other) {
2361+
let mut result = self.clone();
2362+
if result.try_add_checked_in_place(other)? {
2363+
return Ok(result);
2364+
}
2365+
debug_assert!(false, "fast-path eligibility drifted from implementation");
2366+
}
2367+
2368+
let r = add(&self.to_scalar()?, &other.to_scalar()?)?;
21482369
Self::try_from_array(r.as_ref(), 0)
21492370
}
21502371

@@ -5943,6 +6164,68 @@ mod tests {
59436164
Ok(())
59446165
}
59456166

6167+
#[test]
6168+
fn scalar_add_trait_null_test() -> Result<()> {
6169+
let int_value = ScalarValue::Int32(Some(42));
6170+
6171+
assert_eq!(
6172+
int_value.add(ScalarValue::Int32(None))?,
6173+
ScalarValue::Int32(None)
6174+
);
6175+
6176+
Ok(())
6177+
}
6178+
6179+
#[test]
6180+
fn scalar_add_trait_wrapping_overflow_test() -> Result<()> {
6181+
let int_value = ScalarValue::Int32(Some(i32::MAX));
6182+
let one = ScalarValue::Int32(Some(1));
6183+
6184+
assert_eq!(int_value.add(one)?, ScalarValue::Int32(Some(i32::MIN)));
6185+
6186+
Ok(())
6187+
}
6188+
6189+
#[test]
6190+
fn scalar_add_trait_decimal_scale_test() -> Result<()> {
6191+
let decimal = ScalarValue::Decimal128(Some(123), 10, 2);
6192+
let decimal_2 = ScalarValue::Decimal128(Some(4), 9, 1);
6193+
6194+
assert_eq!(
6195+
decimal.add(decimal_2)?,
6196+
ScalarValue::Decimal128(Some(163), 11, 2)
6197+
);
6198+
6199+
Ok(())
6200+
}
6201+
6202+
#[test]
6203+
fn scalar_add_trait_decimal256_scale_test() -> Result<()> {
6204+
let decimal = ScalarValue::Decimal256(Some(i256::from(123)), 10, 2);
6205+
let decimal_2 = ScalarValue::Decimal256(Some(i256::from(4)), 9, 1);
6206+
6207+
assert_eq!(
6208+
decimal.add(decimal_2)?,
6209+
ScalarValue::Decimal256(Some(i256::from(163)), 11, 2)
6210+
);
6211+
6212+
Ok(())
6213+
}
6214+
6215+
#[test]
6216+
fn scalar_add_trait_decimal_negative_scale_test() -> Result<()> {
6217+
let decimal = ScalarValue::Decimal128(Some(1), DECIMAL128_MAX_PRECISION, i8::MIN);
6218+
let decimal_2 =
6219+
ScalarValue::Decimal128(Some(1), DECIMAL128_MAX_PRECISION, i8::MIN);
6220+
6221+
assert_eq!(
6222+
decimal.add(decimal_2)?,
6223+
ScalarValue::Decimal128(Some(2), DECIMAL128_MAX_PRECISION, i8::MIN)
6224+
);
6225+
6226+
Ok(())
6227+
}
6228+
59466229
#[test]
59476230
fn scalar_sub_trait_test() -> Result<()> {
59486231
let float_value = ScalarValue::Float64(Some(123.));
@@ -6042,6 +6325,43 @@ mod tests {
60426325
Ok(())
60436326
}
60446327

6328+
#[test]
6329+
fn scalar_decimal_add_overflow_test() {
6330+
check_scalar_decimal_add_overflow::<Decimal128Type>(
6331+
ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0),
6332+
ScalarValue::Decimal128(Some(1), DECIMAL128_MAX_PRECISION, 0),
6333+
);
6334+
check_scalar_decimal_add_overflow::<Decimal256Type>(
6335+
ScalarValue::Decimal256(Some(i256::MAX), DECIMAL256_MAX_PRECISION, 0),
6336+
ScalarValue::Decimal256(Some(i256::ONE), DECIMAL256_MAX_PRECISION, 0),
6337+
);
6338+
}
6339+
6340+
#[test]
6341+
fn scalar_decimal_in_place_add_error_preserves_lhs() {
6342+
let mut lhs =
6343+
ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0);
6344+
let original = lhs.clone();
6345+
6346+
let err = lhs
6347+
.try_add_checked_in_place(&ScalarValue::Decimal128(
6348+
Some(1),
6349+
DECIMAL128_MAX_PRECISION,
6350+
0,
6351+
))
6352+
.unwrap_err()
6353+
.strip_backtrace();
6354+
6355+
assert_eq!(
6356+
err,
6357+
format!(
6358+
"Arrow error: Arithmetic overflow: Overflow happened on: {} + 1",
6359+
i128::MAX
6360+
)
6361+
);
6362+
assert_eq!(lhs, original);
6363+
}
6364+
60456365
// Verifies that ScalarValue has the same behavior with compute kernel when it overflows.
60466366
fn check_scalar_add_overflow<T>(left: ScalarValue, right: ScalarValue)
60476367
where
@@ -6058,6 +6378,22 @@ mod tests {
60586378
assert_eq!(scalar_result.is_ok(), arrow_result.is_ok());
60596379
}
60606380

6381+
// Verifies the decimal fast path preserves the same overflow behavior as Arrow kernels.
6382+
fn check_scalar_decimal_add_overflow<T>(left: ScalarValue, right: ScalarValue)
6383+
where
6384+
T: ArrowPrimitiveType,
6385+
{
6386+
let scalar_result = left.add(&right);
6387+
6388+
let left_array = left.to_array().expect("Failed to convert to array");
6389+
let right_array = right.to_array().expect("Failed to convert to array");
6390+
let arrow_left_array = left_array.as_primitive::<T>();
6391+
let arrow_right_array = right_array.as_primitive::<T>();
6392+
let arrow_result = add_wrapping(arrow_left_array, arrow_right_array);
6393+
6394+
assert_eq!(scalar_result.is_ok(), arrow_result.is_ok());
6395+
}
6396+
60616397
#[test]
60626398
fn test_interval_add_timestamp() -> Result<()> {
60636399
let interval = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {

0 commit comments

Comments
 (0)