Skip to content

Commit a1e5de3

Browse files
Dandandanclaude
andcommitted
Also try in-place mutation on the right operand
When the left operand's buffer cannot be reused (shared reference), try the right operand for in-place mutation. This covers cases like Scalar-Array and Array-Array where the right side has refcount 1. For non-commutative ops (sub, rem), the argument order is preserved correctly: result[i] = op(left[i], right[i]). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 60e70cd commit a1e5de3

1 file changed

Lines changed: 192 additions & 3 deletions

File tree

  • datafusion/physical-expr-common/src

datafusion/physical-expr-common/src/datum.rs

Lines changed: 192 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,21 @@ pub fn apply_arithmetic(
8787
let f = arithmetic_op_to_fn(op);
8888
match (lhs, rhs) {
8989
(ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
90-
// Try in-place on left array with right array values
90+
// Try in-place on left array, then right array
9191
match try_apply_inplace_array(left, &right, op) {
9292
Ok(result) => Ok(ColumnarValue::Array(result)),
93-
Err(left) => Ok(ColumnarValue::Array(f(&left, &right)?)),
93+
Err(left) => match try_apply_inplace_array_rhs(&left, right, op) {
94+
Ok(result) => Ok(ColumnarValue::Array(result)),
95+
Err(right) => Ok(ColumnarValue::Array(f(&left, &right)?)),
96+
},
9497
}
9598
}
9699
(ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => {
97-
Ok(ColumnarValue::Array(f(&left.to_scalar()?, &right)?))
100+
// Try in-place on right array with scalar left (flipped)
101+
match try_apply_inplace_scalar_rhs(right, &left, op) {
102+
Ok(result) => Ok(ColumnarValue::Array(result)),
103+
Err(right) => Ok(ColumnarValue::Array(f(&left.to_scalar()?, &right)?)),
104+
}
98105
}
99106
(ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => {
100107
// Try in-place on left array with scalar right
@@ -308,6 +315,188 @@ where
308315
}
309316
}
310317

318+
/// Try to apply arithmetic in-place on `right` array with a scalar `left`.
319+
/// The operation is `result[i] = op(scalar_left, right[i])`, stored in `right`'s buffer.
320+
/// Returns `Ok(result)` on success, or `Err(right)` if in-place not possible.
321+
fn try_apply_inplace_scalar_rhs(
322+
right: ArrayRef,
323+
left: &ScalarValue,
324+
op: ArithmeticOp,
325+
) -> Result<ArrayRef, ArrayRef> {
326+
if left.is_null() {
327+
return Err(right);
328+
}
329+
macro_rules! dispatch_inplace_scalar_rhs {
330+
($($arrow_type:ident),*) => {
331+
match right.data_type() {
332+
$(
333+
dt if dt == &<$arrow_type as ArrowPrimitiveType>::DATA_TYPE => {
334+
let scalar_val = left
335+
.to_scalar()
336+
.map_err(|_| Arc::clone(&right))?;
337+
let scalar_arr = scalar_val.get().0;
338+
let lhs_val = scalar_arr
339+
.as_any()
340+
.downcast_ref::<PrimitiveArray<$arrow_type>>()
341+
.ok_or_else(|| Arc::clone(&right))?
342+
.value(0);
343+
try_inplace_unary_rhs::<$arrow_type>(right, lhs_val, op)
344+
}
345+
)*
346+
_ => Err(right),
347+
}
348+
};
349+
}
350+
dispatch_inplace_scalar_rhs!(
351+
Int8Type,
352+
Int16Type,
353+
Int32Type,
354+
Int64Type,
355+
UInt8Type,
356+
UInt16Type,
357+
UInt32Type,
358+
UInt64Type,
359+
Float16Type,
360+
Float32Type,
361+
Float64Type
362+
)
363+
}
364+
365+
/// Try to apply arithmetic in-place on `right` array using values from `left` array.
366+
/// The operation is `result[i] = op(left[i], right[i])`, stored in `right`'s buffer.
367+
/// Returns `Ok(result)` on success, or `Err(right)` if in-place not possible.
368+
fn try_apply_inplace_array_rhs(
369+
left: &ArrayRef,
370+
right: ArrayRef,
371+
op: ArithmeticOp,
372+
) -> Result<ArrayRef, ArrayRef> {
373+
if left.data_type() != right.data_type() {
374+
return Err(right);
375+
}
376+
macro_rules! dispatch_inplace_array_rhs {
377+
($($arrow_type:ident),*) => {
378+
match right.data_type() {
379+
$(
380+
dt if dt == &<$arrow_type as ArrowPrimitiveType>::DATA_TYPE => {
381+
try_inplace_binary_rhs::<$arrow_type>(left, right, op)
382+
}
383+
)*
384+
_ => Err(right),
385+
}
386+
};
387+
}
388+
dispatch_inplace_array_rhs!(
389+
Int8Type,
390+
Int16Type,
391+
Int32Type,
392+
Int64Type,
393+
UInt8Type,
394+
UInt16Type,
395+
UInt32Type,
396+
UInt64Type,
397+
Float16Type,
398+
Float32Type,
399+
Float64Type
400+
)
401+
}
402+
403+
/// Attempt in-place mutation on the right PrimitiveArray: result[i] = op(scalar, right[i]).
404+
fn try_inplace_unary_rhs<T: ArrowPrimitiveType>(
405+
array: ArrayRef,
406+
scalar: T::Native,
407+
op: ArithmeticOp,
408+
) -> Result<ArrayRef, ArrayRef>
409+
where
410+
T::Native: ArrowNativeTypeOp,
411+
{
412+
let primitive = array
413+
.as_any()
414+
.downcast_ref::<PrimitiveArray<T>>()
415+
.ok_or_else(|| Arc::clone(&array))?
416+
.clone();
417+
drop(array);
418+
419+
// For right-side mutation: result = op(scalar, element)
420+
// Commutative ops: same as op(element, scalar)
421+
// Non-commutative: need reversed argument order
422+
type BinFn<N> = fn(N, N) -> N;
423+
let op_fn: Option<BinFn<T::Native>> = match op {
424+
ArithmeticOp::AddWrapping => Some(ArrowNativeTypeOp::add_wrapping),
425+
ArithmeticOp::MulWrapping => Some(ArrowNativeTypeOp::mul_wrapping),
426+
ArithmeticOp::SubWrapping => Some(ArrowNativeTypeOp::sub_wrapping),
427+
ArithmeticOp::Rem if !scalar.is_zero() => Some(ArrowNativeTypeOp::mod_wrapping),
428+
_ => None,
429+
};
430+
431+
let Some(op_fn) = op_fn else {
432+
return Err(Arc::new(primitive));
433+
};
434+
435+
// Note: op(scalar, v) — scalar is the left operand
436+
match primitive.unary_mut(|v| op_fn(scalar, v)) {
437+
Ok(result) => Ok(Arc::new(result)),
438+
Err(arr) => Err(Arc::new(arr)),
439+
}
440+
}
441+
442+
/// Attempt in-place mutation on the right PrimitiveArray: result[i] = op(left[i], right[i]).
443+
fn try_inplace_binary_rhs<T: ArrowPrimitiveType>(
444+
left: &ArrayRef,
445+
right: ArrayRef,
446+
op: ArithmeticOp,
447+
) -> Result<ArrayRef, ArrayRef>
448+
where
449+
T::Native: ArrowNativeTypeOp,
450+
{
451+
let left_primitive = left
452+
.as_any()
453+
.downcast_ref::<PrimitiveArray<T>>()
454+
.ok_or_else(|| Arc::clone(&right))?;
455+
456+
let right_primitive = right
457+
.as_any()
458+
.downcast_ref::<PrimitiveArray<T>>()
459+
.ok_or_else(|| Arc::clone(&right))?
460+
.clone();
461+
drop(right);
462+
463+
let mut builder = match right_primitive.into_builder() {
464+
Ok(b) => b,
465+
Err(arr) => return Err(Arc::new(arr)),
466+
};
467+
468+
type BinFn<N> = fn(N, N) -> N;
469+
let op_fn: Option<BinFn<T::Native>> = match op {
470+
ArithmeticOp::AddWrapping => Some(ArrowNativeTypeOp::add_wrapping),
471+
ArithmeticOp::SubWrapping => Some(ArrowNativeTypeOp::sub_wrapping),
472+
ArithmeticOp::MulWrapping => Some(ArrowNativeTypeOp::mul_wrapping),
473+
_ => None,
474+
};
475+
476+
let Some(op_fn) = op_fn else {
477+
return Err(Arc::new(builder.finish()));
478+
};
479+
480+
let right_slice = builder.values_slice_mut();
481+
let left_values = left_primitive.values();
482+
483+
// Note: op(left[i], right[i]) — left is the first operand
484+
right_slice
485+
.iter_mut()
486+
.zip(left_values.iter())
487+
.for_each(|(r, l)| *r = op_fn(*l, *r));
488+
489+
// Merge null buffers from both sides
490+
let result = builder.finish();
491+
if left_primitive.nulls().is_some() {
492+
let merged = NullBuffer::union(result.nulls(), left_primitive.nulls());
493+
let result = PrimitiveArray::<T>::new(result.values().clone(), merged);
494+
Ok(Arc::new(result))
495+
} else {
496+
Ok(Arc::new(result))
497+
}
498+
}
499+
311500
/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs`
312501
pub fn apply_cmp(
313502
lhs: &ColumnarValue,

0 commit comments

Comments
 (0)