Skip to content

Commit d544735

Browse files
SubhamSinghalNisha AgrawalJefffreySubham Singhalcomphead
authored
feat(spark): Adds negative spark function (#20006)
## What changes are included in this PR? Adds support for negative spark function in data fusion. ## Are these changes tested? yes, using UTs ## Are there any user-facing changes? yes, adds new function. --------- Co-authored-by: Nisha Agrawal <nishaagrawal@Nishas-MacBook-Air.local> Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com> Co-authored-by: Subham Singhal <subhamsinghal@Nishas-MacBook-Air.local> Co-authored-by: Oleks V <comphead@users.noreply.github.com> Co-authored-by: Martin Grigorov <martin-g@users.noreply.github.com> Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 639971a commit d544735

3 files changed

Lines changed: 555 additions & 2 deletions

File tree

datafusion/spark/src/function/math/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub mod expm1;
2020
pub mod factorial;
2121
pub mod hex;
2222
pub mod modulus;
23+
pub mod negative;
2324
pub mod rint;
2425
pub mod trigonometry;
2526
pub mod unhex;
@@ -40,6 +41,7 @@ make_udf_function!(unhex::SparkUnhex, unhex);
4041
make_udf_function!(width_bucket::SparkWidthBucket, width_bucket);
4142
make_udf_function!(trigonometry::SparkCsc, csc);
4243
make_udf_function!(trigonometry::SparkSec, sec);
44+
make_udf_function!(negative::SparkNegative, negative);
4345

4446
pub mod expr_fn {
4547
use datafusion_functions::export_functions;
@@ -63,6 +65,11 @@ pub mod expr_fn {
6365
export_functions!((width_bucket, "Returns the bucket number into which the value of this expression would fall after being evaluated.", arg1 arg2 arg3 arg4));
6466
export_functions!((csc, "Returns the cosecant of expr.", arg1));
6567
export_functions!((sec, "Returns the secant of expr.", arg1));
68+
export_functions!((
69+
negative,
70+
"Returns the negation of expr (unary minus).",
71+
arg1
72+
));
6673
}
6774

6875
pub fn functions() -> Vec<Arc<ScalarUDF>> {
@@ -78,5 +85,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
7885
width_bucket(),
7986
csc(),
8087
sec(),
88+
negative(),
8189
]
8290
}
Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::types::*;
19+
use arrow::array::*;
20+
use arrow::datatypes::{DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit};
21+
use bigdecimal::num_traits::WrappingNeg;
22+
use datafusion_common::utils::take_function_args;
23+
use datafusion_common::{Result, ScalarValue, not_impl_err};
24+
use datafusion_expr::{
25+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
26+
Volatility,
27+
};
28+
use std::any::Any;
29+
use std::sync::Arc;
30+
31+
/// Spark-compatible `negative` expression
32+
/// <https://spark.apache.org/docs/latest/api/sql/index.html#negative>
33+
///
34+
/// Returns the negation of input (equivalent to unary minus)
35+
/// Returns NULL if input is NULL, returns NaN if input is NaN.
36+
///
37+
/// ANSI mode support see (<https://github.com/apache/datafusion/issues/20034>):
38+
/// - Spark's ANSI-compliant dialect, when off (i.e. `spark.sql.ansi.enabled=false`),
39+
/// negating the minimal value of a signed integer wraps around.
40+
/// For example: negative(i32::MIN) returns i32::MIN (wraps instead of error).
41+
/// This is the current implementation (legacy mode only).
42+
/// - Spark's ANSI mode (when `spark.sql.ansi.enabled=true`) should throw an
43+
/// ARITHMETIC_OVERFLOW error on integer overflow instead of wrapping.
44+
/// This is not yet implemented - all operations currently use wrapping behavior.
45+
///
46+
#[derive(Debug, PartialEq, Eq, Hash)]
47+
pub struct SparkNegative {
48+
signature: Signature,
49+
}
50+
51+
impl Default for SparkNegative {
52+
fn default() -> Self {
53+
Self::new()
54+
}
55+
}
56+
57+
impl SparkNegative {
58+
pub fn new() -> Self {
59+
Self {
60+
signature: Signature {
61+
type_signature: TypeSignature::OneOf(vec![
62+
// Numeric types: signed integers, float, decimals
63+
TypeSignature::Numeric(1),
64+
// Interval types: YearMonth, DayTime, MonthDayNano
65+
TypeSignature::Uniform(
66+
1,
67+
vec![
68+
DataType::Interval(IntervalUnit::YearMonth),
69+
DataType::Interval(IntervalUnit::DayTime),
70+
DataType::Interval(IntervalUnit::MonthDayNano),
71+
],
72+
),
73+
]),
74+
volatility: Volatility::Immutable,
75+
parameter_names: None,
76+
},
77+
}
78+
}
79+
}
80+
81+
impl ScalarUDFImpl for SparkNegative {
82+
fn as_any(&self) -> &dyn Any {
83+
self
84+
}
85+
86+
fn name(&self) -> &str {
87+
"negative"
88+
}
89+
90+
fn signature(&self) -> &Signature {
91+
&self.signature
92+
}
93+
94+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95+
Ok(arg_types[0].clone())
96+
}
97+
98+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
99+
spark_negative(&args.args)
100+
}
101+
}
102+
103+
/// Core implementation of Spark's negative function
104+
fn spark_negative(args: &[ColumnarValue]) -> Result<ColumnarValue> {
105+
let [arg] = take_function_args("negative", args)?;
106+
107+
match arg {
108+
ColumnarValue::Array(array) => match array.data_type() {
109+
DataType::Null => Ok(arg.clone()),
110+
111+
// Signed integers - use wrapping negation (Spark legacy mode behavior)
112+
DataType::Int8 => {
113+
let array = array.as_primitive::<Int8Type>();
114+
let result: PrimitiveArray<Int8Type> = array.unary(|x| x.wrapping_neg());
115+
Ok(ColumnarValue::Array(Arc::new(result)))
116+
}
117+
DataType::Int16 => {
118+
let array = array.as_primitive::<Int16Type>();
119+
let result: PrimitiveArray<Int16Type> = array.unary(|x| x.wrapping_neg());
120+
Ok(ColumnarValue::Array(Arc::new(result)))
121+
}
122+
DataType::Int32 => {
123+
let array = array.as_primitive::<Int32Type>();
124+
let result: PrimitiveArray<Int32Type> = array.unary(|x| x.wrapping_neg());
125+
Ok(ColumnarValue::Array(Arc::new(result)))
126+
}
127+
DataType::Int64 => {
128+
let array = array.as_primitive::<Int64Type>();
129+
let result: PrimitiveArray<Int64Type> = array.unary(|x| x.wrapping_neg());
130+
Ok(ColumnarValue::Array(Arc::new(result)))
131+
}
132+
133+
// Floating point - simple negation (no overflow possible)
134+
DataType::Float16 => {
135+
let array = array.as_primitive::<Float16Type>();
136+
let result: PrimitiveArray<Float16Type> = array.unary(|x| -x);
137+
Ok(ColumnarValue::Array(Arc::new(result)))
138+
}
139+
DataType::Float32 => {
140+
let array = array.as_primitive::<Float32Type>();
141+
let result: PrimitiveArray<Float32Type> = array.unary(|x| -x);
142+
Ok(ColumnarValue::Array(Arc::new(result)))
143+
}
144+
DataType::Float64 => {
145+
let array = array.as_primitive::<Float64Type>();
146+
let result: PrimitiveArray<Float64Type> = array.unary(|x| -x);
147+
Ok(ColumnarValue::Array(Arc::new(result)))
148+
}
149+
150+
// Decimal types - wrapping negation
151+
DataType::Decimal32(_, _) => {
152+
let array = array.as_primitive::<Decimal32Type>();
153+
let result: PrimitiveArray<Decimal32Type> =
154+
array.unary(|x| x.wrapping_neg());
155+
Ok(ColumnarValue::Array(Arc::new(result)))
156+
}
157+
DataType::Decimal64(_, _) => {
158+
let array = array.as_primitive::<Decimal64Type>();
159+
let result: PrimitiveArray<Decimal64Type> =
160+
array.unary(|x| x.wrapping_neg());
161+
Ok(ColumnarValue::Array(Arc::new(result)))
162+
}
163+
DataType::Decimal128(_, _) => {
164+
let array = array.as_primitive::<Decimal128Type>();
165+
let result: PrimitiveArray<Decimal128Type> =
166+
array.unary(|x| x.wrapping_neg());
167+
Ok(ColumnarValue::Array(Arc::new(result)))
168+
}
169+
DataType::Decimal256(_, _) => {
170+
let array = array.as_primitive::<Decimal256Type>();
171+
let result: PrimitiveArray<Decimal256Type> =
172+
array.unary(|x| x.wrapping_neg());
173+
Ok(ColumnarValue::Array(Arc::new(result)))
174+
}
175+
176+
// interval type
177+
DataType::Interval(IntervalUnit::YearMonth) => {
178+
let array = array.as_primitive::<IntervalYearMonthType>();
179+
let result: PrimitiveArray<IntervalYearMonthType> =
180+
array.unary(|x| x.wrapping_neg());
181+
Ok(ColumnarValue::Array(Arc::new(result)))
182+
}
183+
DataType::Interval(IntervalUnit::DayTime) => {
184+
let array = array.as_primitive::<IntervalDayTimeType>();
185+
let result: PrimitiveArray<IntervalDayTimeType> =
186+
array.unary(|x| IntervalDayTime {
187+
days: x.days.wrapping_neg(),
188+
milliseconds: x.milliseconds.wrapping_neg(),
189+
});
190+
Ok(ColumnarValue::Array(Arc::new(result)))
191+
}
192+
DataType::Interval(IntervalUnit::MonthDayNano) => {
193+
let array = array.as_primitive::<IntervalMonthDayNanoType>();
194+
let result: PrimitiveArray<IntervalMonthDayNanoType> =
195+
array.unary(|x| IntervalMonthDayNano {
196+
months: x.months.wrapping_neg(),
197+
days: x.days.wrapping_neg(),
198+
nanoseconds: x.nanoseconds.wrapping_neg(),
199+
});
200+
Ok(ColumnarValue::Array(Arc::new(result)))
201+
}
202+
203+
dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"),
204+
},
205+
ColumnarValue::Scalar(sv) => match sv {
206+
ScalarValue::Null => Ok(arg.clone()),
207+
_ if sv.is_null() => Ok(arg.clone()),
208+
209+
// Signed integers - wrapping negation
210+
ScalarValue::Int8(Some(v)) => {
211+
let result = v.wrapping_neg();
212+
Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(result))))
213+
}
214+
ScalarValue::Int16(Some(v)) => {
215+
let result = v.wrapping_neg();
216+
Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(result))))
217+
}
218+
ScalarValue::Int32(Some(v)) => {
219+
let result = v.wrapping_neg();
220+
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result))))
221+
}
222+
ScalarValue::Int64(Some(v)) => {
223+
let result = v.wrapping_neg();
224+
Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result))))
225+
}
226+
227+
// Floating point - simple negation
228+
ScalarValue::Float16(Some(v)) => {
229+
Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(-v))))
230+
}
231+
ScalarValue::Float32(Some(v)) => {
232+
Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(-v))))
233+
}
234+
ScalarValue::Float64(Some(v)) => {
235+
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(-v))))
236+
}
237+
238+
// Decimal types - wrapping negation
239+
ScalarValue::Decimal32(Some(v), precision, scale) => {
240+
let result = v.wrapping_neg();
241+
Ok(ColumnarValue::Scalar(ScalarValue::Decimal32(
242+
Some(result),
243+
*precision,
244+
*scale,
245+
)))
246+
}
247+
ScalarValue::Decimal64(Some(v), precision, scale) => {
248+
let result = v.wrapping_neg();
249+
Ok(ColumnarValue::Scalar(ScalarValue::Decimal64(
250+
Some(result),
251+
*precision,
252+
*scale,
253+
)))
254+
}
255+
ScalarValue::Decimal128(Some(v), precision, scale) => {
256+
let result = v.wrapping_neg();
257+
Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
258+
Some(result),
259+
*precision,
260+
*scale,
261+
)))
262+
}
263+
ScalarValue::Decimal256(Some(v), precision, scale) => {
264+
let result = v.wrapping_neg();
265+
Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
266+
Some(result),
267+
*precision,
268+
*scale,
269+
)))
270+
}
271+
272+
//interval type
273+
ScalarValue::IntervalYearMonth(Some(v)) => Ok(ColumnarValue::Scalar(
274+
ScalarValue::IntervalYearMonth(Some(v.wrapping_neg())),
275+
)),
276+
ScalarValue::IntervalDayTime(Some(v)) => Ok(ColumnarValue::Scalar(
277+
ScalarValue::IntervalDayTime(Some(IntervalDayTime {
278+
days: v.days.wrapping_neg(),
279+
milliseconds: v.milliseconds.wrapping_neg(),
280+
})),
281+
)),
282+
ScalarValue::IntervalMonthDayNano(Some(v)) => Ok(ColumnarValue::Scalar(
283+
ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
284+
months: v.months.wrapping_neg(),
285+
days: v.days.wrapping_neg(),
286+
nanoseconds: v.nanoseconds.wrapping_neg(),
287+
})),
288+
)),
289+
290+
dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"),
291+
},
292+
}
293+
}

0 commit comments

Comments
 (0)