Skip to content

Commit 60e70cd

Browse files
Dandandan and claude committed
Add in-place buffer reuse for arithmetic binary expression evaluation
When evaluating arithmetic binary expressions (+, -, *, /, %), reuse the left operand's buffer in-place when its reference count is 1, avoiding a buffer allocation. This benefits expression chains like a + b + c where intermediate results are consumed only once. Uses PrimitiveArray::unary_mut for array-scalar and into_builder for array-array cases. Only wrapping (infallible) ops use in-place mutation; checked ops fall back to standard Arrow kernels to avoid corrupting buffers on overflow. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5a86142 commit 60e70cd

2 files changed

Lines changed: 279 additions & 12 deletions

File tree

datafusion/physical-expr-common/src/datum.rs

Lines changed: 256 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use arrow::array::types::*;
19+
use arrow::array::ArrowNativeTypeOp;
1820
use arrow::array::BooleanArray;
19-
use arrow::array::{make_comparator, ArrayRef, Datum};
21+
use arrow::array::{
22+
make_comparator, Array, ArrayRef, ArrowPrimitiveType, Datum, PrimitiveArray,
23+
};
2024
use arrow::buffer::NullBuffer;
2125
use arrow::compute::SortOptions;
2226
use arrow::error::ArrowError;
@@ -53,6 +57,257 @@ pub fn apply(
5357
}
5458
}
5559

60+
/// Arithmetic operations that can be applied in-place on primitive arrays.
///
/// Each variant names one Arrow numeric kernel. The `*Wrapping` variants are
/// infallible (overflow wraps around); the remaining variants are treated as
/// fallible and may return an error from the kernel.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ArithmeticOp {
    /// Addition using the `add` kernel (used when overflow must be reported).
    Add,
    /// Addition using the `add_wrapping` kernel.
    AddWrapping,
    /// Subtraction using the `sub` kernel (used when overflow must be reported).
    Sub,
    /// Subtraction using the `sub_wrapping` kernel.
    SubWrapping,
    /// Multiplication using the `mul` kernel (used when overflow must be reported).
    Mul,
    /// Multiplication using the `mul_wrapping` kernel.
    MulWrapping,
    /// Division using the `div` kernel.
    Div,
    /// Remainder using the `rem` kernel.
    Rem,
}
72+
73+
/// Like [`apply`], but takes ownership of `ColumnarValue` inputs to enable
74+
/// in-place buffer reuse for arithmetic on primitive arrays.
75+
///
76+
/// When the left operand is a primitive array whose underlying buffer has a
77+
/// reference count of 1 (i.e. no other consumers), the arithmetic is performed
78+
/// in-place using [`PrimitiveArray::unary_mut`] or [`PrimitiveArray::try_unary_mut`],
79+
/// avoiding a buffer allocation. If in-place mutation is not possible (shared
80+
/// buffer, non-primitive type, etc.) this falls back to the standard Arrow
81+
/// compute kernel.
82+
pub fn apply_arithmetic(
83+
lhs: ColumnarValue,
84+
rhs: ColumnarValue,
85+
op: ArithmeticOp,
86+
) -> Result<ColumnarValue> {
87+
let f = arithmetic_op_to_fn(op);
88+
match (lhs, rhs) {
89+
(ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
90+
// Try in-place on left array with right array values
91+
match try_apply_inplace_array(left, &right, op) {
92+
Ok(result) => Ok(ColumnarValue::Array(result)),
93+
Err(left) => Ok(ColumnarValue::Array(f(&left, &right)?)),
94+
}
95+
}
96+
(ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => {
97+
Ok(ColumnarValue::Array(f(&left.to_scalar()?, &right)?))
98+
}
99+
(ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => {
100+
// Try in-place on left array with scalar right
101+
match try_apply_inplace_scalar(left, &right, op) {
102+
Ok(result) => Ok(ColumnarValue::Array(result)),
103+
Err(left) => Ok(ColumnarValue::Array(f(&left, &right.to_scalar()?)?)),
104+
}
105+
}
106+
(ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => {
107+
let array = f(&left.to_scalar()?, &right.to_scalar()?)?;
108+
let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?;
109+
Ok(ColumnarValue::Scalar(scalar))
110+
}
111+
}
112+
}
113+
114+
fn arithmetic_op_to_fn(
115+
op: ArithmeticOp,
116+
) -> fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError> {
117+
use arrow::compute::kernels::numeric::*;
118+
match op {
119+
ArithmeticOp::Add => add,
120+
ArithmeticOp::AddWrapping => add_wrapping,
121+
ArithmeticOp::Sub => sub,
122+
ArithmeticOp::SubWrapping => sub_wrapping,
123+
ArithmeticOp::Mul => mul,
124+
ArithmeticOp::MulWrapping => mul_wrapping,
125+
ArithmeticOp::Div => div,
126+
ArithmeticOp::Rem => rem,
127+
}
128+
}
129+
130+
/// Try to apply arithmetic in-place on `left` array with a scalar `right`.
131+
/// Returns `Ok(result)` on success, or `Err(left)` if in-place not possible.
132+
fn try_apply_inplace_scalar(
133+
left: ArrayRef,
134+
right: &ScalarValue,
135+
op: ArithmeticOp,
136+
) -> Result<ArrayRef, ArrayRef> {
137+
if right.is_null() {
138+
return Err(left);
139+
}
140+
macro_rules! dispatch_inplace_scalar {
141+
($($arrow_type:ident),*) => {
142+
match left.data_type() {
143+
$(
144+
dt if dt == &<$arrow_type as ArrowPrimitiveType>::DATA_TYPE => {
145+
let scalar_val = right
146+
.to_scalar()
147+
.map_err(|_| Arc::clone(&left))?;
148+
let scalar_arr = scalar_val.get().0;
149+
let rhs_val = scalar_arr
150+
.as_any()
151+
.downcast_ref::<PrimitiveArray<$arrow_type>>()
152+
.ok_or_else(|| Arc::clone(&left))?
153+
.value(0);
154+
try_inplace_unary::<$arrow_type>(left, rhs_val, op)
155+
}
156+
)*
157+
_ => Err(left),
158+
}
159+
};
160+
}
161+
dispatch_inplace_scalar!(
162+
Int8Type,
163+
Int16Type,
164+
Int32Type,
165+
Int64Type,
166+
UInt8Type,
167+
UInt16Type,
168+
UInt32Type,
169+
UInt64Type,
170+
Float16Type,
171+
Float32Type,
172+
Float64Type
173+
)
174+
}
175+
176+
/// Try to apply arithmetic in-place on `left` array using values from `right` array.
177+
/// Returns `Ok(result)` on success, or `Err(left)` if in-place not possible.
178+
fn try_apply_inplace_array(
179+
left: ArrayRef,
180+
right: &ArrayRef,
181+
op: ArithmeticOp,
182+
) -> Result<ArrayRef, ArrayRef> {
183+
if left.data_type() != right.data_type() {
184+
return Err(left);
185+
}
186+
macro_rules! dispatch_inplace_array {
187+
($($arrow_type:ident),*) => {
188+
match left.data_type() {
189+
$(
190+
dt if dt == &<$arrow_type as ArrowPrimitiveType>::DATA_TYPE => {
191+
try_inplace_binary::<$arrow_type>(left, right, op)
192+
}
193+
)*
194+
_ => Err(left),
195+
}
196+
};
197+
}
198+
dispatch_inplace_array!(
199+
Int8Type,
200+
Int16Type,
201+
Int32Type,
202+
Int64Type,
203+
UInt8Type,
204+
UInt16Type,
205+
UInt32Type,
206+
UInt64Type,
207+
Float16Type,
208+
Float32Type,
209+
Float64Type
210+
)
211+
}
212+
213+
/// Attempt in-place unary (array op scalar) mutation on a PrimitiveArray.
214+
fn try_inplace_unary<T: ArrowPrimitiveType>(
215+
array: ArrayRef,
216+
scalar: T::Native,
217+
op: ArithmeticOp,
218+
) -> Result<ArrayRef, ArrayRef>
219+
where
220+
T::Native: ArrowNativeTypeOp,
221+
{
222+
// Clone the PrimitiveArray (cheap — shares the buffer via Arc)
223+
let primitive = array
224+
.as_any()
225+
.downcast_ref::<PrimitiveArray<T>>()
226+
.ok_or_else(|| Arc::clone(&array))?
227+
.clone();
228+
// Drop the ArrayRef so the buffer's refcount can drop to 1
229+
drop(array);
230+
231+
// Only attempt in-place for wrapping (infallible) operations.
232+
// Checked ops (Add, Sub, Mul, Div) can fail mid-way, corrupting the buffer.
233+
// Rem with zero divisor must also fall back for proper error reporting.
234+
type BinFn<N> = fn(N, N) -> N;
235+
let op_fn: Option<BinFn<T::Native>> = match op {
236+
ArithmeticOp::AddWrapping => Some(ArrowNativeTypeOp::add_wrapping),
237+
ArithmeticOp::SubWrapping => Some(ArrowNativeTypeOp::sub_wrapping),
238+
ArithmeticOp::MulWrapping => Some(ArrowNativeTypeOp::mul_wrapping),
239+
ArithmeticOp::Rem if !scalar.is_zero() => Some(ArrowNativeTypeOp::mod_wrapping),
240+
_ => None,
241+
};
242+
243+
let Some(op_fn) = op_fn else {
244+
return Err(Arc::new(primitive));
245+
};
246+
247+
match primitive.unary_mut(|v| op_fn(v, scalar)) {
248+
Ok(result) => Ok(Arc::new(result)),
249+
Err(arr) => Err(Arc::new(arr)),
250+
}
251+
}
252+
253+
/// Attempt in-place binary (array op array) mutation on a PrimitiveArray.
254+
fn try_inplace_binary<T: ArrowPrimitiveType>(
255+
left: ArrayRef,
256+
right: &ArrayRef,
257+
op: ArithmeticOp,
258+
) -> Result<ArrayRef, ArrayRef>
259+
where
260+
T::Native: ArrowNativeTypeOp,
261+
{
262+
let right_primitive = right
263+
.as_any()
264+
.downcast_ref::<PrimitiveArray<T>>()
265+
.ok_or_else(|| Arc::clone(&left))?;
266+
267+
let left_primitive = left
268+
.as_any()
269+
.downcast_ref::<PrimitiveArray<T>>()
270+
.ok_or_else(|| Arc::clone(&left))?
271+
.clone();
272+
drop(left);
273+
274+
let mut builder = match left_primitive.into_builder() {
275+
Ok(b) => b,
276+
Err(arr) => return Err(Arc::new(arr)),
277+
};
278+
279+
// Only attempt in-place for wrapping (infallible) operations.
280+
type BinFn<N> = fn(N, N) -> N;
281+
let op_fn: Option<BinFn<T::Native>> = match op {
282+
ArithmeticOp::AddWrapping => Some(ArrowNativeTypeOp::add_wrapping),
283+
ArithmeticOp::SubWrapping => Some(ArrowNativeTypeOp::sub_wrapping),
284+
ArithmeticOp::MulWrapping => Some(ArrowNativeTypeOp::mul_wrapping),
285+
_ => None,
286+
};
287+
288+
let Some(op_fn) = op_fn else {
289+
return Err(Arc::new(builder.finish()));
290+
};
291+
292+
let left_slice = builder.values_slice_mut();
293+
let right_values = right_primitive.values();
294+
295+
left_slice
296+
.iter_mut()
297+
.zip(right_values.iter())
298+
.for_each(|(l, r)| *l = op_fn(*l, *r));
299+
300+
// Merge null buffers from both sides
301+
let result = builder.finish();
302+
if right_primitive.nulls().is_some() {
303+
let merged = NullBuffer::union(result.nulls(), right_primitive.nulls());
304+
let result = PrimitiveArray::<T>::new(result.values().clone(), merged);
305+
Ok(Arc::new(result))
306+
} else {
307+
Ok(Arc::new(result))
308+
}
309+
}
310+
56311
/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs`
57312
pub fn apply_cmp(
58313
lhs: &ColumnarValue,

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ use datafusion_expr::statistics::{
4444
new_generic_from_binary_op, Distribution,
4545
};
4646
use datafusion_expr::{ColumnarValue, Operator};
47-
use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};
47+
use datafusion_physical_expr_common::datum::{
48+
apply_arithmetic, apply_cmp, apply_cmp_for_nested, ArithmeticOp,
49+
};
4850

4951
use kernels::{
5052
bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
@@ -357,8 +359,6 @@ impl PhysicalExpr for BinaryExpr {
357359
}
358360

359361
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
360-
use arrow::compute::kernels::numeric::*;
361-
362362
// Evaluate left-hand side expression.
363363
let lhs = self.left.evaluate(batch)?;
364364

@@ -394,14 +394,26 @@ impl PhysicalExpr for BinaryExpr {
394394
}
395395

396396
match self.op {
397-
Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
398-
Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
399-
Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub),
400-
Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
401-
Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul),
402-
Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
403-
Operator::Divide => return apply(&lhs, &rhs, div),
404-
Operator::Modulo => return apply(&lhs, &rhs, rem),
397+
Operator::Plus if self.fail_on_overflow => {
398+
return apply_arithmetic(lhs, rhs, ArithmeticOp::Add)
399+
}
400+
Operator::Plus => {
401+
return apply_arithmetic(lhs, rhs, ArithmeticOp::AddWrapping)
402+
}
403+
Operator::Minus if self.fail_on_overflow => {
404+
return apply_arithmetic(lhs, rhs, ArithmeticOp::Sub)
405+
}
406+
Operator::Minus => {
407+
return apply_arithmetic(lhs, rhs, ArithmeticOp::SubWrapping)
408+
}
409+
Operator::Multiply if self.fail_on_overflow => {
410+
return apply_arithmetic(lhs, rhs, ArithmeticOp::Mul)
411+
}
412+
Operator::Multiply => {
413+
return apply_arithmetic(lhs, rhs, ArithmeticOp::MulWrapping)
414+
}
415+
Operator::Divide => return apply_arithmetic(lhs, rhs, ArithmeticOp::Div),
416+
Operator::Modulo => return apply_arithmetic(lhs, rhs, ArithmeticOp::Rem),
405417
Operator::Eq => return apply_cmp(&lhs, &rhs, eq),
406418
Operator::NotEq => return apply_cmp(&lhs, &rhs, neq),
407419
Operator::Lt => return apply_cmp(&lhs, &rhs, lt),

0 commit comments

Comments
 (0)