Skip to content

Commit 4c6f444

Browse files
Dandandan and claude committed
Also try in-place mutation on the right operand
When the left operand's buffer cannot be reused (shared reference), try the right operand for in-place mutation. This covers cases like Scalar-Array and Array-Array where the right side has refcount 1. For non-commutative ops (sub, rem), the argument order is preserved correctly: result[i] = op(left[i], right[i]). Also refactors type dispatch into shared macros. Decimal types are excluded from in-place mutation because the result precision/scale differs from the input. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e73a965 commit 4c6f444

1 file changed

Lines changed: 242 additions & 62 deletions

File tree

  • datafusion/physical-expr-common/src

datafusion/physical-expr-common/src/datum.rs

Lines changed: 242 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -89,14 +89,21 @@ pub fn apply_arithmetic(
8989
let f = arithmetic_op_to_fn(op);
9090
match (lhs, rhs) {
9191
(ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
92-
// Try in-place on left array with right array values
92+
// Try in-place on left array, then right array
9393
match try_apply_inplace_array(left, &right, op) {
9494
Ok(result) => Ok(ColumnarValue::Array(result)),
95-
Err(left) => Ok(ColumnarValue::Array(f(&left, &right)?)),
95+
Err(left) => match try_apply_inplace_array_rhs(&left, right, op) {
96+
Ok(result) => Ok(ColumnarValue::Array(result)),
97+
Err(right) => Ok(ColumnarValue::Array(f(&left, &right)?)),
98+
},
9699
}
97100
}
98101
(ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => {
99-
Ok(ColumnarValue::Array(f(&left.to_scalar()?, &right)?))
102+
// Try in-place on right array with scalar left (flipped)
103+
match try_apply_inplace_scalar_rhs(right, &left, op) {
104+
Ok(result) => Ok(ColumnarValue::Array(result)),
105+
Err(right) => Ok(ColumnarValue::Array(f(&left.to_scalar()?, &right)?)),
106+
}
100107
}
101108
(ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => {
102109
// Try in-place on left array with scalar right
@@ -129,6 +136,110 @@ fn arithmetic_op_to_fn(
129136
}
130137
}
131138

139+
/// Type-dispatches an in-place "array op scalar" kernel over the supported primitive types.
///
/// * `$arr` - the `ArrayRef` whose buffer may be reused. It is handed back as
///   `Err($arr)` when its data type is not one of the listed primitives.
/// * `$scalar_value` - the `&ScalarValue` operand. It is converted with
///   `to_scalar()` and element 0 of the resulting array is read as the native value.
/// * `$op` - the `ArithmeticOp` forwarded to the kernel.
/// * `$fn_name` - the kernel to call (`try_inplace_unary` or `try_inplace_unary_rhs`).
///
/// The expansion uses `?` with an `ArrayRef` error, so it must be expanded inside a
/// function returning `Result<ArrayRef, ArrayRef>`.
macro_rules! dispatch_inplace_unary {
    // Internal rule: emits one guarded match arm per listed Arrow primitive type.
    // It must stay first so that `@types` is never parsed as an expression.
    (@types [$($arrow_type:ty),+ $(,)?] $arr:expr, $scalar_value:expr, $op:expr, $fn_name:ident) => {{
        match $arr.data_type() {
            $(
                dt if dt == &<$arrow_type as ArrowPrimitiveType>::DATA_TYPE => {
                    // Any failure while extracting the native scalar returns the
                    // array to the caller instead of surfacing an error.
                    let scalar_val = $scalar_value
                        .to_scalar()
                        .map_err(|_| Arc::clone(&$arr))?;
                    let scalar_arr = scalar_val.get().0;
                    let rhs_val = scalar_arr
                        .as_any()
                        .downcast_ref::<PrimitiveArray<$arrow_type>>()
                        .ok_or_else(|| Arc::clone(&$arr))?
                        .value(0);
                    $fn_name::<$arrow_type>($arr, rhs_val, $op)
                }
            )+
            // Decimal types excluded: result precision/scale differs from input
            _ => Err($arr),
        }
    }};
    // Public entry point: supplies the list of supported primitive types.
    ($arr:expr, $scalar_value:expr, $op:expr, $fn_name:ident) => {
        dispatch_inplace_unary!(
            @types [
                Int8Type,
                Int16Type,
                Int32Type,
                Int64Type,
                UInt8Type,
                UInt16Type,
                UInt32Type,
                UInt64Type,
                Float16Type,
                Float32Type,
                Float64Type,
            ]
            $arr, $scalar_value, $op, $fn_name
        )
    };
}
197+
198+
/// Type-dispatches an in-place "array op array" kernel over the supported primitive types.
///
/// * `$arr` - the `ArrayRef` whose buffer may be reused. Its data type selects the
///   kernel instantiation, and it is handed back as `Err($arr)` when that type is
///   not one of the listed primitives.
/// * `$other` - the other operand, passed through to the kernel unchanged.
/// * `$op` - the `ArithmeticOp` forwarded to the kernel.
/// * `$fn_name` - the kernel to call (`try_inplace_binary` or `try_inplace_binary_rhs`).
macro_rules! dispatch_inplace_binary {
    // Internal rule: emits one guarded match arm per listed Arrow primitive type.
    // It must stay first so that `@types` is never parsed as an expression.
    (@types [$($arrow_type:ty),+ $(,)?] $arr:expr, $other:expr, $op:expr, $fn_name:ident) => {{
        match $arr.data_type() {
            $(
                dt if dt == &<$arrow_type as ArrowPrimitiveType>::DATA_TYPE => {
                    $fn_name::<$arrow_type>($arr, $other, $op)
                }
            )+
            // Decimal types excluded: result precision/scale differs from input
            _ => Err($arr),
        }
    }};
    // Public entry point: supplies the list of supported primitive types.
    ($arr:expr, $other:expr, $op:expr, $fn_name:ident) => {
        dispatch_inplace_binary!(
            @types [
                Int8Type,
                Int16Type,
                Int32Type,
                Int64Type,
                UInt8Type,
                UInt16Type,
                UInt32Type,
                UInt64Type,
                Float16Type,
                Float32Type,
                Float64Type,
            ]
            $arr, $other, $op, $fn_name
        )
    };
}
242+
132243
/// Try to apply arithmetic in-place on `left` array with a scalar `right`.
133244
/// Returns `Ok(result)` on success, or `Err(left)` if in-place not possible.
134245
fn try_apply_inplace_scalar(
@@ -139,40 +250,7 @@ fn try_apply_inplace_scalar(
139250
if right.is_null() {
140251
return Err(left);
141252
}
142-
macro_rules! dispatch_inplace_scalar {
143-
($($arrow_type:ident),*) => {
144-
match left.data_type() {
145-
$(
146-
dt if dt == &<$arrow_type as ArrowPrimitiveType>::DATA_TYPE => {
147-
let scalar_val = right
148-
.to_scalar()
149-
.map_err(|_| Arc::clone(&left))?;
150-
let scalar_arr = scalar_val.get().0;
151-
let rhs_val = scalar_arr
152-
.as_any()
153-
.downcast_ref::<PrimitiveArray<$arrow_type>>()
154-
.ok_or_else(|| Arc::clone(&left))?
155-
.value(0);
156-
try_inplace_unary::<$arrow_type>(left, rhs_val, op)
157-
}
158-
)*
159-
_ => Err(left),
160-
}
161-
};
162-
}
163-
dispatch_inplace_scalar!(
164-
Int8Type,
165-
Int16Type,
166-
Int32Type,
167-
Int64Type,
168-
UInt8Type,
169-
UInt16Type,
170-
UInt32Type,
171-
UInt64Type,
172-
Float16Type,
173-
Float32Type,
174-
Float64Type
175-
)
253+
dispatch_inplace_unary!(left, right, op, try_inplace_unary)
176254
}
177255

178256
/// Try to apply arithmetic in-place on `left` array using values from `right` array.
@@ -185,31 +263,7 @@ fn try_apply_inplace_array(
185263
if left.data_type() != right.data_type() {
186264
return Err(left);
187265
}
188-
macro_rules! dispatch_inplace_array {
189-
($($arrow_type:ident),*) => {
190-
match left.data_type() {
191-
$(
192-
dt if dt == &<$arrow_type as ArrowPrimitiveType>::DATA_TYPE => {
193-
try_inplace_binary::<$arrow_type>(left, right, op)
194-
}
195-
)*
196-
_ => Err(left),
197-
}
198-
};
199-
}
200-
dispatch_inplace_array!(
201-
Int8Type,
202-
Int16Type,
203-
Int32Type,
204-
Int64Type,
205-
UInt8Type,
206-
UInt16Type,
207-
UInt32Type,
208-
UInt64Type,
209-
Float16Type,
210-
Float32Type,
211-
Float64Type
212-
)
266+
dispatch_inplace_binary!(left, right, op, try_inplace_binary)
213267
}
214268

215269
/// Attempt in-place unary (array op scalar) mutation on a PrimitiveArray.
@@ -310,6 +364,132 @@ where
310364
}
311365
}
312366

367+
/// Try to apply arithmetic in-place on `right` array with a scalar `left`.
368+
/// The operation is `result[i] = op(scalar_left, right[i])`, stored in `right`'s buffer.
369+
/// Returns `Ok(result)` on success, or `Err(right)` if in-place not possible.
370+
fn try_apply_inplace_scalar_rhs(
371+
right: ArrayRef,
372+
left: &ScalarValue,
373+
op: ArithmeticOp,
374+
) -> Result<ArrayRef, ArrayRef> {
375+
if left.is_null() {
376+
return Err(right);
377+
}
378+
dispatch_inplace_unary!(right, left, op, try_inplace_unary_rhs)
379+
}
380+
381+
/// Try to apply arithmetic in-place on `right` array using values from `left` array.
382+
/// The operation is `result[i] = op(left[i], right[i])`, stored in `right`'s buffer.
383+
/// Returns `Ok(result)` on success, or `Err(right)` if in-place not possible.
384+
fn try_apply_inplace_array_rhs(
385+
left: &ArrayRef,
386+
right: ArrayRef,
387+
op: ArithmeticOp,
388+
) -> Result<ArrayRef, ArrayRef> {
389+
if left.data_type() != right.data_type() {
390+
return Err(right);
391+
}
392+
dispatch_inplace_binary!(right, left, op, try_inplace_binary_rhs)
393+
}
394+
395+
/// Attempt in-place mutation on the right PrimitiveArray: result[i] = op(scalar, right[i]).
396+
fn try_inplace_unary_rhs<T: ArrowPrimitiveType>(
397+
array: ArrayRef,
398+
scalar: T::Native,
399+
op: ArithmeticOp,
400+
) -> Result<ArrayRef, ArrayRef>
401+
where
402+
T::Native: ArrowNativeTypeOp,
403+
{
404+
let primitive = array
405+
.as_any()
406+
.downcast_ref::<PrimitiveArray<T>>()
407+
.ok_or_else(|| Arc::clone(&array))?
408+
.clone();
409+
drop(array);
410+
411+
// For right-side mutation: result = op(scalar, element)
412+
// Commutative ops: same as op(element, scalar)
413+
// Non-commutative: need reversed argument order
414+
type BinFn<N> = fn(N, N) -> N;
415+
let op_fn: Option<BinFn<T::Native>> = match op {
416+
ArithmeticOp::AddWrapping => Some(ArrowNativeTypeOp::add_wrapping),
417+
ArithmeticOp::MulWrapping => Some(ArrowNativeTypeOp::mul_wrapping),
418+
ArithmeticOp::SubWrapping => Some(ArrowNativeTypeOp::sub_wrapping),
419+
ArithmeticOp::Rem if !scalar.is_zero() => Some(ArrowNativeTypeOp::mod_wrapping),
420+
_ => None,
421+
};
422+
423+
let Some(op_fn) = op_fn else {
424+
return Err(Arc::new(primitive));
425+
};
426+
427+
// Note: op(scalar, v) — scalar is the left operand
428+
match primitive.unary_mut(|v| op_fn(scalar, v)) {
429+
Ok(result) => Ok(Arc::new(result)),
430+
Err(arr) => Err(Arc::new(arr)),
431+
}
432+
}
433+
434+
/// Attempt in-place mutation on the right PrimitiveArray: result[i] = op(left[i], right[i]).
435+
/// Note: parameter order is (right_owned, left_ref) to match the dispatch_inplace_binary macro.
436+
fn try_inplace_binary_rhs<T: ArrowPrimitiveType>(
437+
right: ArrayRef,
438+
left: &ArrayRef,
439+
op: ArithmeticOp,
440+
) -> Result<ArrayRef, ArrayRef>
441+
where
442+
T::Native: ArrowNativeTypeOp,
443+
{
444+
let left_primitive = left
445+
.as_any()
446+
.downcast_ref::<PrimitiveArray<T>>()
447+
.ok_or_else(|| Arc::clone(&right))?;
448+
449+
let right_primitive = right
450+
.as_any()
451+
.downcast_ref::<PrimitiveArray<T>>()
452+
.ok_or_else(|| Arc::clone(&right))?
453+
.clone();
454+
drop(right);
455+
456+
let mut builder = match right_primitive.into_builder() {
457+
Ok(b) => b,
458+
Err(arr) => return Err(Arc::new(arr)),
459+
};
460+
461+
type BinFn<N> = fn(N, N) -> N;
462+
let op_fn: Option<BinFn<T::Native>> = match op {
463+
ArithmeticOp::AddWrapping => Some(ArrowNativeTypeOp::add_wrapping),
464+
ArithmeticOp::SubWrapping => Some(ArrowNativeTypeOp::sub_wrapping),
465+
ArithmeticOp::MulWrapping => Some(ArrowNativeTypeOp::mul_wrapping),
466+
_ => None,
467+
};
468+
469+
let Some(op_fn) = op_fn else {
470+
return Err(Arc::new(builder.finish()));
471+
};
472+
473+
let right_slice = builder.values_slice_mut();
474+
let left_values = left_primitive.values();
475+
476+
// Note: op(left[i], right[i]) — left is the first operand
477+
right_slice
478+
.iter_mut()
479+
.zip(left_values.iter())
480+
.for_each(|(r, l)| *r = op_fn(*l, *r));
481+
482+
// Merge null buffers from both sides
483+
let result = builder.finish();
484+
if left_primitive.nulls().is_some() {
485+
let merged = NullBuffer::union(result.nulls(), left_primitive.nulls());
486+
let result = PrimitiveArray::<T>::new(result.values().clone(), merged);
487+
Ok(Arc::new(result))
488+
} else {
489+
Ok(Arc::new(result))
490+
}
491+
}
492+
313493
/// Applies a binary [`Datum`] comparison operator `op` to `lhs` and `rhs`
314494
pub fn apply_cmp(
315495
op: Operator,

0 commit comments

Comments
 (0)