Skip to content

Commit 29c5dd5

Browse files
shivbhatia10Shiv Bhatiacomphead
authored
[datafusion-spark] Add Spark-compatible ceil function (#20593)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> Part of #15914 and apache/datafusion-comet#1704 I also just noticed #15916 but I believe work has been stale on this one. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Helping to continue adding Spark compatible expressions to datafusion-spark. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> Add new `ceil` function. ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Yes, unit tests. ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> No. <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> --------- Co-authored-by: Shiv Bhatia <sbhatia@palantir.com> Co-authored-by: Oleks V <comphead@users.noreply.github.com>
1 parent c253bfb commit 29c5dd5

File tree

4 files changed

+428
-46
lines changed

4 files changed

+428
-46
lines changed
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::array::{ArrowNativeTypeOp, AsArray, Decimal128Array};
21+
use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type, Int64Type};
22+
use datafusion_common::utils::take_function_args;
23+
use datafusion_common::{Result, ScalarValue, exec_err};
24+
use datafusion_expr::{
25+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
26+
};
27+
28+
/// Spark-compatible `ceil` expression
/// <https://spark.apache.org/docs/latest/api/sql/index.html#ceil>
///
/// Differences with DataFusion ceil:
/// - Spark's ceil returns Int64 for float inputs; DataFusion preserves
/// the input type (Float32→Float32, Float64→Float64)
/// - Spark's ceil on Decimal128(p, s) returns Decimal128(p−s+1, 0), reducing scale
/// to 0; DataFusion preserves the original precision and scale
/// - Spark only supports Decimal128; DataFusion also supports Decimal32/64/256
/// - Spark does not check for decimal overflow; DataFusion errors on overflow
///
/// 2-argument ceil(value, scale) is not yet implemented
/// <https://github.com/apache/datafusion/issues/21560>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkCeil {
    // One numeric argument, immutable volatility (set up in `new`).
    signature: Signature,
    // Alternate names the UDF is registered under ("ceiling").
    aliases: Vec<String>,
}
46+
47+
impl Default for SparkCeil {
48+
fn default() -> Self {
49+
Self::new()
50+
}
51+
}
52+
53+
impl SparkCeil {
54+
pub fn new() -> Self {
55+
Self {
56+
signature: Signature::numeric(1, Volatility::Immutable),
57+
aliases: vec!["ceiling".to_string()],
58+
}
59+
}
60+
}
61+
62+
impl ScalarUDFImpl for SparkCeil {
63+
fn name(&self) -> &str {
64+
"ceil"
65+
}
66+
67+
fn signature(&self) -> &Signature {
68+
&self.signature
69+
}
70+
71+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
72+
match &arg_types[0] {
73+
DataType::Decimal128(p, s) => {
74+
if *s > 0 {
75+
Ok(DataType::Decimal128(decimal128_ceil_precision(*p, *s), 0))
76+
} else {
77+
// scale <= 0 means the value is already a whole number
78+
// (or represents multiples of 10^(-scale)), so ceil is a no-op
79+
Ok(DataType::Decimal128(*p, *s))
80+
}
81+
}
82+
dt if matches!(dt, DataType::Float32 | DataType::Float64)
83+
|| dt.is_integer() =>
84+
{
85+
Ok(DataType::Int64)
86+
}
87+
other => exec_err!("Unsupported data type {other:?} for function ceil"),
88+
}
89+
}
90+
91+
fn aliases(&self) -> &[String] {
92+
&self.aliases
93+
}
94+
95+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
96+
spark_ceil(&args.args)
97+
}
98+
}
99+
100+
fn spark_ceil(args: &[ColumnarValue]) -> Result<ColumnarValue> {
101+
let [input] = take_function_args("ceil", args)?;
102+
103+
match input {
104+
ColumnarValue::Scalar(value) => spark_ceil_scalar(value),
105+
ColumnarValue::Array(input) => spark_ceil_array(input),
106+
}
107+
}
108+
109+
/// Compute ceil for a single decimal128 value with the given scale.
///
/// `value` holds the unscaled integer representation; the logical value is
/// `value / 10^scale`, and the result is its ceiling as an unscaled integer
/// with scale 0.
#[inline]
fn decimal128_ceil(value: i128, scale: u32) -> i128 {
    // Decimal128 scales are at most 38, so 10^scale fits in i128.
    let divisor = 10_i128.pow(scale);
    let quotient = value / divisor;
    let remainder = value % divisor;
    // Truncating division already equals ceil for exact multiples and for
    // negative values; only a positive remainder needs rounding up.
    quotient + i128::from(remainder > 0)
}
117+
118+
/// Compute the return precision for a decimal128 ceil result.
///
/// The rule is `p - s + 1` (ceil may gain one integer digit), clamped to
/// the valid Decimal128 precision range [1, 38].
#[inline]
fn decimal128_ceil_precision(precision: u8, scale: i8) -> u8 {
    // Widen before arithmetic so the subtraction cannot wrap.
    let widened = i64::from(precision) - i64::from(scale) + 1;
    widened.clamp(1, 38) as u8
}
123+
124+
fn spark_ceil_scalar(value: &ScalarValue) -> Result<ColumnarValue> {
125+
let result = match value {
126+
ScalarValue::Float32(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)),
127+
ScalarValue::Float64(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)),
128+
v if v.data_type().is_integer() => v.cast_to(&DataType::Int64)?,
129+
ScalarValue::Decimal128(v, p, s) if *s > 0 => {
130+
let new_p = decimal128_ceil_precision(*p, *s);
131+
ScalarValue::Decimal128(v.map(|x| decimal128_ceil(x, *s as u32)), new_p, 0)
132+
}
133+
ScalarValue::Decimal128(_, _, _) => value.clone(),
134+
other => {
135+
return exec_err!(
136+
"Unsupported data type {:?} for function ceil",
137+
other.data_type()
138+
);
139+
}
140+
};
141+
Ok(ColumnarValue::Scalar(result))
142+
}
143+
144+
fn spark_ceil_array(input: &Arc<dyn arrow::array::Array>) -> Result<ColumnarValue> {
145+
let result = match input.data_type() {
146+
DataType::Float32 => Arc::new(
147+
input
148+
.as_primitive::<Float32Type>()
149+
.unary::<_, Int64Type>(|x| x.ceil() as i64),
150+
) as _,
151+
DataType::Float64 => Arc::new(
152+
input
153+
.as_primitive::<Float64Type>()
154+
.unary::<_, Int64Type>(|x| x.ceil() as i64),
155+
) as _,
156+
dt if dt.is_integer() => arrow::compute::cast(input, &DataType::Int64)?,
157+
DataType::Decimal128(p, s) if *s > 0 => {
158+
let new_p = decimal128_ceil_precision(*p, *s);
159+
let result: Decimal128Array = input
160+
.as_primitive::<Decimal128Type>()
161+
.unary(|x| decimal128_ceil(x, *s as u32));
162+
Arc::new(result.with_data_type(DataType::Decimal128(new_p, 0)))
163+
}
164+
DataType::Decimal128(_, _) => Arc::clone(input),
165+
other => return exec_err!("Unsupported data type {other:?} for function ceil"),
166+
};
167+
168+
Ok(ColumnarValue::Array(result))
169+
}
170+
171+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Decimal128Array, Float32Array, Float64Array, Int64Array};
    use datafusion_common::ScalarValue;

    /// Invoke `spark_ceil` with one array argument and unwrap the array result.
    fn ceil_array(input: Arc<dyn arrow::array::Array>) -> Arc<dyn arrow::array::Array> {
        match spark_ceil(&[ColumnarValue::Array(input)]).unwrap() {
            ColumnarValue::Array(arr) => arr,
            _ => panic!("Expected array"),
        }
    }

    /// Invoke `spark_ceil` with one scalar argument and unwrap the scalar result.
    fn ceil_scalar(input: ScalarValue) -> ScalarValue {
        match spark_ceil(&[ColumnarValue::Scalar(input)]).unwrap() {
            ColumnarValue::Scalar(v) => v,
            _ => panic!("Expected scalar"),
        }
    }

    #[test]
    fn test_ceil_float64() {
        let input = Float64Array::from(vec![
            Some(125.2345),
            Some(15.0001),
            Some(0.1),
            Some(-0.9),
            Some(-1.1),
            Some(123.0),
            None,
        ]);
        let result = ceil_array(Arc::new(input));
        let expected = Int64Array::from(vec![
            Some(126),
            Some(16),
            Some(1),
            Some(0),
            Some(-1),
            Some(123),
            None,
        ]);
        assert_eq!(result.as_primitive::<Int64Type>(), &expected);
    }

    #[test]
    fn test_ceil_float32() {
        let input = Float32Array::from(vec![
            Some(125.2345f32),
            Some(15.0001f32),
            Some(0.1f32),
            Some(-0.9f32),
            Some(-1.1f32),
            Some(123.0f32),
            None,
        ]);
        let result = ceil_array(Arc::new(input));
        let expected = Int64Array::from(vec![
            Some(126),
            Some(16),
            Some(1),
            Some(0),
            Some(-1),
            Some(123),
            None,
        ]);
        assert_eq!(result.as_primitive::<Int64Type>(), &expected);
    }

    #[test]
    fn test_ceil_int64() {
        // Integer input passes through (widened to Int64).
        let input = Int64Array::from(vec![Some(1), Some(-1), None]);
        let result = ceil_array(Arc::new(input));
        assert_eq!(
            result.as_primitive::<Int64Type>(),
            &Int64Array::from(vec![Some(1), Some(-1), None])
        );
    }

    #[test]
    fn test_ceil_decimal128() {
        // Decimal128(10, 2): 150 = 1.50, -150 = -1.50, 100 = 1.00
        let input = Decimal128Array::from(vec![Some(150), Some(-150), Some(100), None])
            .with_data_type(DataType::Decimal128(10, 2));
        let result = ceil_array(Arc::new(input));
        // Result type is Decimal128(10 - 2 + 1, 0) = Decimal128(9, 0).
        let expected = Decimal128Array::from(vec![Some(2), Some(-1), Some(1), None])
            .with_data_type(DataType::Decimal128(9, 0));
        assert_eq!(result.as_primitive::<Decimal128Type>(), &expected);
    }

    #[test]
    fn test_ceil_float64_scalar() {
        let result = ceil_scalar(ScalarValue::Float64(Some(-1.1)));
        assert_eq!(result, ScalarValue::Int64(Some(-1)));
    }

    #[test]
    fn test_ceil_float32_scalar() {
        let result = ceil_scalar(ScalarValue::Float32(Some(125.2345f32)));
        assert_eq!(result, ScalarValue::Int64(Some(126)));
    }

    #[test]
    fn test_ceil_int64_scalar() {
        let result = ceil_scalar(ScalarValue::Int64(Some(48)));
        assert_eq!(result, ScalarValue::Int64(Some(48)));
    }
}

datafusion/spark/src/function/math/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
pub mod abs;
1919
pub mod bin;
20+
pub mod ceil;
2021
pub mod expm1;
2122
pub mod factorial;
2223
pub mod hex;
@@ -33,6 +34,7 @@ use datafusion_functions::make_udf_function;
3334
use std::sync::Arc;
3435

3536
make_udf_function!(abs::SparkAbs, abs);
37+
make_udf_function!(ceil::SparkCeil, ceil);
3638
make_udf_function!(expm1::SparkExpm1, expm1);
3739
make_udf_function!(factorial::SparkFactorial, factorial);
3840
make_udf_function!(hex::SparkHex, hex);
@@ -51,6 +53,7 @@ pub mod expr_fn {
5153
use datafusion_functions::export_functions;
5254

5355
export_functions!((abs, "Returns abs(expr)", arg1));
56+
export_functions!((ceil, "Returns the ceiling of expr.", arg1));
5457
export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1));
5558
export_functions!((
5659
factorial,
@@ -89,6 +92,7 @@ pub mod expr_fn {
8992
pub fn functions() -> Vec<Arc<ScalarUDF>> {
9093
vec![
9194
abs(),
95+
ceil(),
9296
expm1(),
9397
factorial(),
9498
hex(),

0 commit comments

Comments
 (0)