Skip to content

Commit 418ccdd

Browse files
Dandandanclaude
andcommitted
Also try in-place mutation on the right operand
When the left operand's buffer cannot be reused (shared reference), try the right operand for in-place mutation. This covers cases like Scalar-Array and Array-Array where the right side has refcount 1. For non-commutative ops (sub, rem), the argument order is preserved correctly: result[i] = op(left[i], right[i]). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e73a965 commit 418ccdd

1 file changed

Lines changed: 192 additions & 3 deletions

File tree

  • datafusion/physical-expr-common/src

datafusion/physical-expr-common/src/datum.rs

Lines changed: 192 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,21 @@ pub fn apply_arithmetic(
8989
let f = arithmetic_op_to_fn(op);
9090
match (lhs, rhs) {
9191
(ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
92-
// Try in-place on left array with right array values
92+
// Try in-place on left array, then right array
9393
match try_apply_inplace_array(left, &right, op) {
9494
Ok(result) => Ok(ColumnarValue::Array(result)),
95-
Err(left) => Ok(ColumnarValue::Array(f(&left, &right)?)),
95+
Err(left) => match try_apply_inplace_array_rhs(&left, right, op) {
96+
Ok(result) => Ok(ColumnarValue::Array(result)),
97+
Err(right) => Ok(ColumnarValue::Array(f(&left, &right)?)),
98+
},
9699
}
97100
}
98101
(ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => {
99-
Ok(ColumnarValue::Array(f(&left.to_scalar()?, &right)?))
102+
// Try in-place on right array with scalar left (flipped)
103+
match try_apply_inplace_scalar_rhs(right, &left, op) {
104+
Ok(result) => Ok(ColumnarValue::Array(result)),
105+
Err(right) => Ok(ColumnarValue::Array(f(&left.to_scalar()?, &right)?)),
106+
}
100107
}
101108
(ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => {
102109
// Try in-place on left array with scalar right
@@ -310,6 +317,188 @@ where
310317
}
311318
}
312319

320+
/// Try to apply arithmetic in-place on `right` array with a scalar `left`.
321+
/// The operation is `result[i] = op(scalar_left, right[i])`, stored in `right`'s buffer.
322+
/// Returns `Ok(result)` on success, or `Err(right)` if in-place not possible.
323+
fn try_apply_inplace_scalar_rhs(
324+
right: ArrayRef,
325+
left: &ScalarValue,
326+
op: ArithmeticOp,
327+
) -> Result<ArrayRef, ArrayRef> {
328+
if left.is_null() {
329+
return Err(right);
330+
}
331+
macro_rules! dispatch_inplace_scalar_rhs {
332+
($($arrow_type:ident),*) => {
333+
match right.data_type() {
334+
$(
335+
dt if dt == &<$arrow_type as ArrowPrimitiveType>::DATA_TYPE => {
336+
let scalar_val = left
337+
.to_scalar()
338+
.map_err(|_| Arc::clone(&right))?;
339+
let scalar_arr = scalar_val.get().0;
340+
let lhs_val = scalar_arr
341+
.as_any()
342+
.downcast_ref::<PrimitiveArray<$arrow_type>>()
343+
.ok_or_else(|| Arc::clone(&right))?
344+
.value(0);
345+
try_inplace_unary_rhs::<$arrow_type>(right, lhs_val, op)
346+
}
347+
)*
348+
_ => Err(right),
349+
}
350+
};
351+
}
352+
dispatch_inplace_scalar_rhs!(
353+
Int8Type,
354+
Int16Type,
355+
Int32Type,
356+
Int64Type,
357+
UInt8Type,
358+
UInt16Type,
359+
UInt32Type,
360+
UInt64Type,
361+
Float16Type,
362+
Float32Type,
363+
Float64Type
364+
)
365+
}
366+
367+
/// Try to apply arithmetic in-place on `right` array using values from `left` array.
368+
/// The operation is `result[i] = op(left[i], right[i])`, stored in `right`'s buffer.
369+
/// Returns `Ok(result)` on success, or `Err(right)` if in-place not possible.
370+
fn try_apply_inplace_array_rhs(
371+
left: &ArrayRef,
372+
right: ArrayRef,
373+
op: ArithmeticOp,
374+
) -> Result<ArrayRef, ArrayRef> {
375+
if left.data_type() != right.data_type() {
376+
return Err(right);
377+
}
378+
macro_rules! dispatch_inplace_array_rhs {
379+
($($arrow_type:ident),*) => {
380+
match right.data_type() {
381+
$(
382+
dt if dt == &<$arrow_type as ArrowPrimitiveType>::DATA_TYPE => {
383+
try_inplace_binary_rhs::<$arrow_type>(left, right, op)
384+
}
385+
)*
386+
_ => Err(right),
387+
}
388+
};
389+
}
390+
dispatch_inplace_array_rhs!(
391+
Int8Type,
392+
Int16Type,
393+
Int32Type,
394+
Int64Type,
395+
UInt8Type,
396+
UInt16Type,
397+
UInt32Type,
398+
UInt64Type,
399+
Float16Type,
400+
Float32Type,
401+
Float64Type
402+
)
403+
}
404+
405+
/// Attempt in-place mutation on the right PrimitiveArray: result[i] = op(scalar, right[i]).
406+
fn try_inplace_unary_rhs<T: ArrowPrimitiveType>(
407+
array: ArrayRef,
408+
scalar: T::Native,
409+
op: ArithmeticOp,
410+
) -> Result<ArrayRef, ArrayRef>
411+
where
412+
T::Native: ArrowNativeTypeOp,
413+
{
414+
let primitive = array
415+
.as_any()
416+
.downcast_ref::<PrimitiveArray<T>>()
417+
.ok_or_else(|| Arc::clone(&array))?
418+
.clone();
419+
drop(array);
420+
421+
// For right-side mutation: result = op(scalar, element)
422+
// Commutative ops: same as op(element, scalar)
423+
// Non-commutative: need reversed argument order
424+
type BinFn<N> = fn(N, N) -> N;
425+
let op_fn: Option<BinFn<T::Native>> = match op {
426+
ArithmeticOp::AddWrapping => Some(ArrowNativeTypeOp::add_wrapping),
427+
ArithmeticOp::MulWrapping => Some(ArrowNativeTypeOp::mul_wrapping),
428+
ArithmeticOp::SubWrapping => Some(ArrowNativeTypeOp::sub_wrapping),
429+
ArithmeticOp::Rem if !scalar.is_zero() => Some(ArrowNativeTypeOp::mod_wrapping),
430+
_ => None,
431+
};
432+
433+
let Some(op_fn) = op_fn else {
434+
return Err(Arc::new(primitive));
435+
};
436+
437+
// Note: op(scalar, v) — scalar is the left operand
438+
match primitive.unary_mut(|v| op_fn(scalar, v)) {
439+
Ok(result) => Ok(Arc::new(result)),
440+
Err(arr) => Err(Arc::new(arr)),
441+
}
442+
}
443+
444+
/// Attempt in-place mutation on the right PrimitiveArray: result[i] = op(left[i], right[i]).
445+
fn try_inplace_binary_rhs<T: ArrowPrimitiveType>(
446+
left: &ArrayRef,
447+
right: ArrayRef,
448+
op: ArithmeticOp,
449+
) -> Result<ArrayRef, ArrayRef>
450+
where
451+
T::Native: ArrowNativeTypeOp,
452+
{
453+
let left_primitive = left
454+
.as_any()
455+
.downcast_ref::<PrimitiveArray<T>>()
456+
.ok_or_else(|| Arc::clone(&right))?;
457+
458+
let right_primitive = right
459+
.as_any()
460+
.downcast_ref::<PrimitiveArray<T>>()
461+
.ok_or_else(|| Arc::clone(&right))?
462+
.clone();
463+
drop(right);
464+
465+
let mut builder = match right_primitive.into_builder() {
466+
Ok(b) => b,
467+
Err(arr) => return Err(Arc::new(arr)),
468+
};
469+
470+
type BinFn<N> = fn(N, N) -> N;
471+
let op_fn: Option<BinFn<T::Native>> = match op {
472+
ArithmeticOp::AddWrapping => Some(ArrowNativeTypeOp::add_wrapping),
473+
ArithmeticOp::SubWrapping => Some(ArrowNativeTypeOp::sub_wrapping),
474+
ArithmeticOp::MulWrapping => Some(ArrowNativeTypeOp::mul_wrapping),
475+
_ => None,
476+
};
477+
478+
let Some(op_fn) = op_fn else {
479+
return Err(Arc::new(builder.finish()));
480+
};
481+
482+
let right_slice = builder.values_slice_mut();
483+
let left_values = left_primitive.values();
484+
485+
// Note: op(left[i], right[i]) — left is the first operand
486+
right_slice
487+
.iter_mut()
488+
.zip(left_values.iter())
489+
.for_each(|(r, l)| *r = op_fn(*l, *r));
490+
491+
// Merge null buffers from both sides
492+
let result = builder.finish();
493+
if left_primitive.nulls().is_some() {
494+
let merged = NullBuffer::union(result.nulls(), left_primitive.nulls());
495+
let result = PrimitiveArray::<T>::new(result.values().clone(), merged);
496+
Ok(Arc::new(result))
497+
} else {
498+
Ok(Arc::new(result))
499+
}
500+
}
501+
313502
/// Applies a binary [`Datum`] comparison operator `op` to `lhs` and `rhs`
314503
pub fn apply_cmp(
315504
op: Operator,

0 commit comments

Comments
 (0)