Skip to content

Commit e73a965

Browse files
Dandandan and claude committed
Add in-place buffer reuse for arithmetic binary expression evaluation
When evaluating arithmetic binary expressions (+, -, *, /, %), reuse the left operand's buffer in-place when its reference count is 1, avoiding a buffer allocation. This benefits expression chains like a + b + c where intermediate results are consumed only once. Uses PrimitiveArray::unary_mut for array-scalar and into_builder for array-array cases. Only wrapping (infallible) ops use in-place mutation; checked ops fall back to standard Arrow kernels to avoid corrupting buffers on overflow. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1e93a67 commit e73a965

2 files changed

Lines changed: 279 additions & 12 deletions

File tree

datafusion/physical-expr-common/src/datum.rs

Lines changed: 256 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use arrow::array::ArrowNativeTypeOp;
1819
use arrow::array::BooleanArray;
19-
use arrow::array::{ArrayRef, Datum, make_comparator};
20+
use arrow::array::types::*;
21+
use arrow::array::{
22+
Array, ArrayRef, ArrowPrimitiveType, Datum, PrimitiveArray, make_comparator,
23+
};
2024
use arrow::buffer::{BooleanBuffer, NullBuffer};
2125
use arrow::compute::kernels::cmp::{
2226
distinct, eq, gt, gt_eq, lt, lt_eq, neq, not_distinct,
@@ -55,6 +59,257 @@ pub fn apply(
5559
}
5660
}
5761

62+
/// Arithmetic operations that can be applied in-place on primitive arrays.
///
/// Each variant names one Arrow numeric kernel. The plain `Add`/`Sub`/`Mul`
/// variants are the checked forms, which can fail; the `*Wrapping` variants
/// are the infallible forms.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ArithmeticOp {
    /// Checked addition (the `add` kernel).
    Add,
    /// Wrapping addition (the `add_wrapping` kernel).
    AddWrapping,
    /// Checked subtraction (the `sub` kernel).
    Sub,
    /// Wrapping subtraction (the `sub_wrapping` kernel).
    SubWrapping,
    /// Checked multiplication (the `mul` kernel).
    Mul,
    /// Wrapping multiplication (the `mul_wrapping` kernel).
    MulWrapping,
    /// Division (the `div` kernel).
    Div,
    /// Remainder (the `rem` kernel).
    Rem,
}
74+
75+
/// Like [`apply`], but takes ownership of `ColumnarValue` inputs to enable
76+
/// in-place buffer reuse for arithmetic on primitive arrays.
77+
///
78+
/// When the left operand is a primitive array whose underlying buffer has a
79+
/// reference count of 1 (i.e. no other consumers), the arithmetic is performed
80+
/// in-place using [`PrimitiveArray::unary_mut`] or [`PrimitiveArray::try_unary_mut`],
81+
/// avoiding a buffer allocation. If in-place mutation is not possible (shared
82+
/// buffer, non-primitive type, etc.) this falls back to the standard Arrow
83+
/// compute kernel.
84+
pub fn apply_arithmetic(
85+
lhs: ColumnarValue,
86+
rhs: ColumnarValue,
87+
op: ArithmeticOp,
88+
) -> Result<ColumnarValue> {
89+
let f = arithmetic_op_to_fn(op);
90+
match (lhs, rhs) {
91+
(ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
92+
// Try in-place on left array with right array values
93+
match try_apply_inplace_array(left, &right, op) {
94+
Ok(result) => Ok(ColumnarValue::Array(result)),
95+
Err(left) => Ok(ColumnarValue::Array(f(&left, &right)?)),
96+
}
97+
}
98+
(ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => {
99+
Ok(ColumnarValue::Array(f(&left.to_scalar()?, &right)?))
100+
}
101+
(ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => {
102+
// Try in-place on left array with scalar right
103+
match try_apply_inplace_scalar(left, &right, op) {
104+
Ok(result) => Ok(ColumnarValue::Array(result)),
105+
Err(left) => Ok(ColumnarValue::Array(f(&left, &right.to_scalar()?)?)),
106+
}
107+
}
108+
(ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => {
109+
let array = f(&left.to_scalar()?, &right.to_scalar()?)?;
110+
let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?;
111+
Ok(ColumnarValue::Scalar(scalar))
112+
}
113+
}
114+
}
115+
116+
fn arithmetic_op_to_fn(
117+
op: ArithmeticOp,
118+
) -> fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError> {
119+
use arrow::compute::kernels::numeric::*;
120+
match op {
121+
ArithmeticOp::Add => add,
122+
ArithmeticOp::AddWrapping => add_wrapping,
123+
ArithmeticOp::Sub => sub,
124+
ArithmeticOp::SubWrapping => sub_wrapping,
125+
ArithmeticOp::Mul => mul,
126+
ArithmeticOp::MulWrapping => mul_wrapping,
127+
ArithmeticOp::Div => div,
128+
ArithmeticOp::Rem => rem,
129+
}
130+
}
131+
132+
/// Try to apply arithmetic in-place on `left` array with a scalar `right`.
133+
/// Returns `Ok(result)` on success, or `Err(left)` if in-place not possible.
134+
fn try_apply_inplace_scalar(
135+
left: ArrayRef,
136+
right: &ScalarValue,
137+
op: ArithmeticOp,
138+
) -> Result<ArrayRef, ArrayRef> {
139+
if right.is_null() {
140+
return Err(left);
141+
}
142+
macro_rules! dispatch_inplace_scalar {
143+
($($arrow_type:ident),*) => {
144+
match left.data_type() {
145+
$(
146+
dt if dt == &<$arrow_type as ArrowPrimitiveType>::DATA_TYPE => {
147+
let scalar_val = right
148+
.to_scalar()
149+
.map_err(|_| Arc::clone(&left))?;
150+
let scalar_arr = scalar_val.get().0;
151+
let rhs_val = scalar_arr
152+
.as_any()
153+
.downcast_ref::<PrimitiveArray<$arrow_type>>()
154+
.ok_or_else(|| Arc::clone(&left))?
155+
.value(0);
156+
try_inplace_unary::<$arrow_type>(left, rhs_val, op)
157+
}
158+
)*
159+
_ => Err(left),
160+
}
161+
};
162+
}
163+
dispatch_inplace_scalar!(
164+
Int8Type,
165+
Int16Type,
166+
Int32Type,
167+
Int64Type,
168+
UInt8Type,
169+
UInt16Type,
170+
UInt32Type,
171+
UInt64Type,
172+
Float16Type,
173+
Float32Type,
174+
Float64Type
175+
)
176+
}
177+
178+
/// Try to apply arithmetic in-place on `left` array using values from `right` array.
179+
/// Returns `Ok(result)` on success, or `Err(left)` if in-place not possible.
180+
fn try_apply_inplace_array(
181+
left: ArrayRef,
182+
right: &ArrayRef,
183+
op: ArithmeticOp,
184+
) -> Result<ArrayRef, ArrayRef> {
185+
if left.data_type() != right.data_type() {
186+
return Err(left);
187+
}
188+
macro_rules! dispatch_inplace_array {
189+
($($arrow_type:ident),*) => {
190+
match left.data_type() {
191+
$(
192+
dt if dt == &<$arrow_type as ArrowPrimitiveType>::DATA_TYPE => {
193+
try_inplace_binary::<$arrow_type>(left, right, op)
194+
}
195+
)*
196+
_ => Err(left),
197+
}
198+
};
199+
}
200+
dispatch_inplace_array!(
201+
Int8Type,
202+
Int16Type,
203+
Int32Type,
204+
Int64Type,
205+
UInt8Type,
206+
UInt16Type,
207+
UInt32Type,
208+
UInt64Type,
209+
Float16Type,
210+
Float32Type,
211+
Float64Type
212+
)
213+
}
214+
215+
/// Attempt in-place unary (array op scalar) mutation on a PrimitiveArray.
216+
fn try_inplace_unary<T: ArrowPrimitiveType>(
217+
array: ArrayRef,
218+
scalar: T::Native,
219+
op: ArithmeticOp,
220+
) -> Result<ArrayRef, ArrayRef>
221+
where
222+
T::Native: ArrowNativeTypeOp,
223+
{
224+
// Clone the PrimitiveArray (cheap — shares the buffer via Arc)
225+
let primitive = array
226+
.as_any()
227+
.downcast_ref::<PrimitiveArray<T>>()
228+
.ok_or_else(|| Arc::clone(&array))?
229+
.clone();
230+
// Drop the ArrayRef so the buffer's refcount can drop to 1
231+
drop(array);
232+
233+
// Only attempt in-place for wrapping (infallible) operations.
234+
// Checked ops (Add, Sub, Mul, Div) can fail mid-way, corrupting the buffer.
235+
// Rem with zero divisor must also fall back for proper error reporting.
236+
type BinFn<N> = fn(N, N) -> N;
237+
let op_fn: Option<BinFn<T::Native>> = match op {
238+
ArithmeticOp::AddWrapping => Some(ArrowNativeTypeOp::add_wrapping),
239+
ArithmeticOp::SubWrapping => Some(ArrowNativeTypeOp::sub_wrapping),
240+
ArithmeticOp::MulWrapping => Some(ArrowNativeTypeOp::mul_wrapping),
241+
ArithmeticOp::Rem if !scalar.is_zero() => Some(ArrowNativeTypeOp::mod_wrapping),
242+
_ => None,
243+
};
244+
245+
let Some(op_fn) = op_fn else {
246+
return Err(Arc::new(primitive));
247+
};
248+
249+
match primitive.unary_mut(|v| op_fn(v, scalar)) {
250+
Ok(result) => Ok(Arc::new(result)),
251+
Err(arr) => Err(Arc::new(arr)),
252+
}
253+
}
254+
255+
/// Attempt in-place binary (array op array) mutation on a PrimitiveArray.
256+
fn try_inplace_binary<T: ArrowPrimitiveType>(
257+
left: ArrayRef,
258+
right: &ArrayRef,
259+
op: ArithmeticOp,
260+
) -> Result<ArrayRef, ArrayRef>
261+
where
262+
T::Native: ArrowNativeTypeOp,
263+
{
264+
let right_primitive = right
265+
.as_any()
266+
.downcast_ref::<PrimitiveArray<T>>()
267+
.ok_or_else(|| Arc::clone(&left))?;
268+
269+
let left_primitive = left
270+
.as_any()
271+
.downcast_ref::<PrimitiveArray<T>>()
272+
.ok_or_else(|| Arc::clone(&left))?
273+
.clone();
274+
drop(left);
275+
276+
let mut builder = match left_primitive.into_builder() {
277+
Ok(b) => b,
278+
Err(arr) => return Err(Arc::new(arr)),
279+
};
280+
281+
// Only attempt in-place for wrapping (infallible) operations.
282+
type BinFn<N> = fn(N, N) -> N;
283+
let op_fn: Option<BinFn<T::Native>> = match op {
284+
ArithmeticOp::AddWrapping => Some(ArrowNativeTypeOp::add_wrapping),
285+
ArithmeticOp::SubWrapping => Some(ArrowNativeTypeOp::sub_wrapping),
286+
ArithmeticOp::MulWrapping => Some(ArrowNativeTypeOp::mul_wrapping),
287+
_ => None,
288+
};
289+
290+
let Some(op_fn) = op_fn else {
291+
return Err(Arc::new(builder.finish()));
292+
};
293+
294+
let left_slice = builder.values_slice_mut();
295+
let right_values = right_primitive.values();
296+
297+
left_slice
298+
.iter_mut()
299+
.zip(right_values.iter())
300+
.for_each(|(l, r)| *l = op_fn(*l, *r));
301+
302+
// Merge null buffers from both sides
303+
let result = builder.finish();
304+
if right_primitive.nulls().is_some() {
305+
let merged = NullBuffer::union(result.nulls(), right_primitive.nulls());
306+
let result = PrimitiveArray::<T>::new(result.values().clone(), merged);
307+
Ok(Arc::new(result))
308+
} else {
309+
Ok(Arc::new(result))
310+
}
311+
}
312+
58313
/// Applies a binary [`Datum`] comparison operator `op` to `lhs` and `rhs`
59314
pub fn apply_cmp(
60315
op: Operator,

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ use datafusion_expr::statistics::{
4040
create_bernoulli_from_comparison, new_generic_from_binary_op,
4141
};
4242
use datafusion_expr::{ColumnarValue, Operator};
43-
use datafusion_physical_expr_common::datum::{apply, apply_cmp};
43+
use datafusion_physical_expr_common::datum::{
44+
ArithmeticOp, apply, apply_arithmetic, apply_cmp,
45+
};
4446

4547
use kernels::{
4648
bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
@@ -271,8 +273,6 @@ impl PhysicalExpr for BinaryExpr {
271273
}
272274

273275
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
274-
use arrow::compute::kernels::numeric::*;
275-
276276
// Evaluate left-hand side expression.
277277
let lhs = self.left.evaluate(batch)?;
278278

@@ -338,19 +338,31 @@ impl PhysicalExpr for BinaryExpr {
338338
let input_schema = schema.as_ref();
339339

340340
match self.op {
341-
Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
342-
Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
341+
Operator::Plus if self.fail_on_overflow => {
342+
return apply_arithmetic(lhs, rhs, ArithmeticOp::Add);
343+
}
344+
Operator::Plus => {
345+
return apply_arithmetic(lhs, rhs, ArithmeticOp::AddWrapping);
346+
}
343347
// Special case: Date - Date returns Int64 (days difference)
344348
// This aligns with PostgreSQL, DuckDB, and MySQL behavior
345349
Operator::Minus if is_date_minus_date(&left_data_type, &right_data_type) => {
346350
return apply_date_subtraction(&lhs, &rhs);
347351
}
348-
Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub),
349-
Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
350-
Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul),
351-
Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
352-
Operator::Divide => return apply(&lhs, &rhs, div),
353-
Operator::Modulo => return apply(&lhs, &rhs, rem),
352+
Operator::Minus if self.fail_on_overflow => {
353+
return apply_arithmetic(lhs, rhs, ArithmeticOp::Sub);
354+
}
355+
Operator::Minus => {
356+
return apply_arithmetic(lhs, rhs, ArithmeticOp::SubWrapping);
357+
}
358+
Operator::Multiply if self.fail_on_overflow => {
359+
return apply_arithmetic(lhs, rhs, ArithmeticOp::Mul);
360+
}
361+
Operator::Multiply => {
362+
return apply_arithmetic(lhs, rhs, ArithmeticOp::MulWrapping);
363+
}
364+
Operator::Divide => return apply_arithmetic(lhs, rhs, ArithmeticOp::Div),
365+
Operator::Modulo => return apply_arithmetic(lhs, rhs, ArithmeticOp::Rem),
354366

355367
Operator::Eq
356368
| Operator::NotEq

0 commit comments

Comments
 (0)