Skip to content

Commit c849374

Browse files
authored
Refactor iszero() and isnan() to accept all numeric types (#20093)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #20089 ## Rationale for this change iszero() and isnan() previously accepted “numeric” inputs by implicitly coercing them to Float64, adding unnecessary casts and work for integer/decimal inputs. <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? - Updated iszero() and isnan() signatures to accept TypeSignatureClass::Numeric without implicit float coercion. - Refactored iszero() implementation to evaluate zero checks directly for integers, unsigned integers, floats, and decimals. - Refactored isnan() implementation to compute is_nan only for float types and return false for all other numeric types and reduced non-float array handling to use unary kernels instead of manual iteration. - Added sqllogictest <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? Yes <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? No <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 51c0475 commit c849374

3 files changed

Lines changed: 278 additions & 51 deletions

File tree

datafusion/functions/src/math/iszero.rs

Lines changed: 114 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,16 @@ use std::any::Any;
1919
use std::sync::Arc;
2020

2121
use arrow::array::{ArrowNativeTypeOp, AsArray, BooleanArray};
22-
use arrow::datatypes::DataType::{Boolean, Float16, Float32, Float64};
23-
use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type};
22+
use arrow::datatypes::DataType::{
23+
Boolean, Decimal32, Decimal64, Decimal128, Decimal256, Float16, Float32, Float64,
24+
Int8, Int16, Int32, Int64, Null, UInt8, UInt16, UInt32, UInt64,
25+
};
26+
use arrow::datatypes::{
27+
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type,
28+
Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
29+
UInt16Type, UInt32Type, UInt64Type,
30+
};
2431

25-
use datafusion_common::types::NativeType;
2632
use datafusion_common::utils::take_function_args;
2733
use datafusion_common::{Result, ScalarValue, internal_err};
2834
use datafusion_expr::{Coercion, TypeSignatureClass};
@@ -59,14 +65,10 @@ impl Default for IsZeroFunc {
5965

6066
impl IsZeroFunc {
6167
pub fn new() -> Self {
62-
// Accept any numeric type and coerce to float
63-
let float = Coercion::new_implicit(
64-
TypeSignatureClass::Float,
65-
vec![TypeSignatureClass::Numeric],
66-
NativeType::Float64,
67-
);
68+
// Accept any numeric type (ints, uints, floats, decimals) without implicit casts.
69+
let numeric = Coercion::new_exact(TypeSignatureClass::Numeric);
6870
Self {
69-
signature: Signature::coercible(vec![float], Volatility::Immutable),
71+
signature: Signature::coercible(vec![numeric], Volatility::Immutable),
7072
}
7173
}
7274
}
@@ -107,6 +109,45 @@ impl ScalarUDFImpl for IsZeroFunc {
107109
ScalarValue::Float16(Some(v)) => Ok(ColumnarValue::Scalar(
108110
ScalarValue::Boolean(Some(v.is_zero())),
109111
)),
112+
113+
ScalarValue::Int8(Some(v)) => {
114+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
115+
}
116+
ScalarValue::Int16(Some(v)) => {
117+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
118+
}
119+
ScalarValue::Int32(Some(v)) => {
120+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
121+
}
122+
ScalarValue::Int64(Some(v)) => {
123+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
124+
}
125+
ScalarValue::UInt8(Some(v)) => {
126+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
127+
}
128+
ScalarValue::UInt16(Some(v)) => {
129+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
130+
}
131+
ScalarValue::UInt32(Some(v)) => {
132+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
133+
}
134+
ScalarValue::UInt64(Some(v)) => {
135+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
136+
}
137+
138+
ScalarValue::Decimal32(Some(v), ..) => {
139+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
140+
}
141+
ScalarValue::Decimal64(Some(v), ..) => {
142+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
143+
}
144+
ScalarValue::Decimal128(Some(v), ..) => {
145+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0))))
146+
}
147+
ScalarValue::Decimal256(Some(v), ..) => Ok(ColumnarValue::Scalar(
148+
ScalarValue::Boolean(Some(v.is_zero())),
149+
)),
150+
110151
_ => {
111152
internal_err!(
112153
"Unexpected scalar type for iszero: {:?}",
@@ -116,6 +157,10 @@ impl ScalarUDFImpl for IsZeroFunc {
116157
}
117158
}
118159
ColumnarValue::Array(array) => match array.data_type() {
160+
Null => Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null(
161+
array.len(),
162+
)))),
163+
119164
Float64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
120165
array.as_primitive::<Float64Type>(),
121166
|x| x == 0.0,
@@ -128,6 +173,65 @@ impl ScalarUDFImpl for IsZeroFunc {
128173
array.as_primitive::<Float16Type>(),
129174
|x| x.is_zero(),
130175
)))),
176+
177+
Int8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
178+
array.as_primitive::<Int8Type>(),
179+
|x| x == 0,
180+
)))),
181+
Int16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
182+
array.as_primitive::<Int16Type>(),
183+
|x| x == 0,
184+
)))),
185+
Int32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
186+
array.as_primitive::<Int32Type>(),
187+
|x| x == 0,
188+
)))),
189+
Int64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
190+
array.as_primitive::<Int64Type>(),
191+
|x| x == 0,
192+
)))),
193+
UInt8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
194+
array.as_primitive::<UInt8Type>(),
195+
|x| x == 0,
196+
)))),
197+
UInt16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
198+
array.as_primitive::<UInt16Type>(),
199+
|x| x == 0,
200+
)))),
201+
UInt32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
202+
array.as_primitive::<UInt32Type>(),
203+
|x| x == 0,
204+
)))),
205+
UInt64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
206+
array.as_primitive::<UInt64Type>(),
207+
|x| x == 0,
208+
)))),
209+
210+
Decimal32(_, _) => {
211+
Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
212+
array.as_primitive::<Decimal32Type>(),
213+
|x| x == 0,
214+
))))
215+
}
216+
Decimal64(_, _) => {
217+
Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
218+
array.as_primitive::<Decimal64Type>(),
219+
|x| x == 0,
220+
))))
221+
}
222+
Decimal128(_, _) => {
223+
Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
224+
array.as_primitive::<Decimal128Type>(),
225+
|x| x == 0,
226+
))))
227+
}
228+
Decimal256(_, _) => {
229+
Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary(
230+
array.as_primitive::<Decimal256Type>(),
231+
|x| x.is_zero(),
232+
))))
233+
}
234+
131235
other => {
132236
internal_err!("Unexpected data type {other:?} for function iszero")
133237
}

datafusion/functions/src/math/nans.rs

Lines changed: 132 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,21 @@
1717

1818
//! Math function: `isnan()`.
1919
20-
use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type};
21-
use datafusion_common::types::NativeType;
22-
use datafusion_common::{Result, ScalarValue, exec_err};
23-
use datafusion_expr::{Coercion, ColumnarValue, ScalarFunctionArgs, TypeSignatureClass};
24-
2520
use arrow::array::{ArrayRef, AsArray, BooleanArray};
26-
use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility};
21+
use arrow::datatypes::DataType::{
22+
Decimal32, Decimal64, Decimal128, Decimal256, Float16, Float32, Float64, Int8, Int16,
23+
Int32, Int64, Null, UInt8, UInt16, UInt32, UInt64,
24+
};
25+
use arrow::datatypes::{
26+
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type,
27+
Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
28+
UInt16Type, UInt32Type, UInt64Type,
29+
};
30+
use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args};
31+
use datafusion_expr::{
32+
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
33+
TypeSignatureClass, Volatility,
34+
};
2735
use datafusion_macros::user_doc;
2836
use std::any::Any;
2937
use std::sync::Arc;
@@ -55,14 +63,10 @@ impl Default for IsNanFunc {
5563

5664
impl IsNanFunc {
5765
pub fn new() -> Self {
58-
// Accept any numeric type and coerce to float
59-
let float = Coercion::new_implicit(
60-
TypeSignatureClass::Float,
61-
vec![TypeSignatureClass::Numeric],
62-
NativeType::Float64,
63-
);
66+
// Accept any numeric type (ints, uints, floats, decimals) without implicit casts.
67+
let numeric = Coercion::new_exact(TypeSignatureClass::Numeric);
6468
Self {
65-
signature: Signature::coercible(vec![float], Volatility::Immutable),
69+
signature: Signature::coercible(vec![numeric], Volatility::Immutable),
6670
}
6771
}
6872
}
@@ -84,36 +88,123 @@ impl ScalarUDFImpl for IsNanFunc {
8488
}
8589

8690
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
87-
// Handle NULL input
88-
if args.args[0].data_type().is_null() {
89-
return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
90-
}
91+
let [arg] = take_function_args(self.name(), args.args)?;
92+
93+
match arg {
94+
ColumnarValue::Scalar(scalar) => {
95+
if scalar.is_null() {
96+
return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
97+
}
98+
99+
let result = match scalar {
100+
ScalarValue::Float64(Some(v)) => Some(v.is_nan()),
101+
ScalarValue::Float32(Some(v)) => Some(v.is_nan()),
102+
ScalarValue::Float16(Some(v)) => Some(v.is_nan()),
91103

92-
let args = ColumnarValue::values_to_arrays(&args.args)?;
93-
94-
let arr: ArrayRef = match args[0].data_type() {
95-
DataType::Float64 => Arc::new(BooleanArray::from_unary(
96-
args[0].as_primitive::<Float64Type>(),
97-
f64::is_nan,
98-
)) as ArrayRef,
99-
100-
DataType::Float32 => Arc::new(BooleanArray::from_unary(
101-
args[0].as_primitive::<Float32Type>(),
102-
f32::is_nan,
103-
)) as ArrayRef,
104-
105-
DataType::Float16 => Arc::new(BooleanArray::from_unary(
106-
args[0].as_primitive::<Float16Type>(),
107-
|x| x.is_nan(),
108-
)) as ArrayRef,
109-
other => {
110-
return exec_err!(
111-
"Unsupported data type {other:?} for function {}",
112-
self.name()
113-
);
104+
// Non-float numeric inputs are never NaN
105+
ScalarValue::Int8(_)
106+
| ScalarValue::Int16(_)
107+
| ScalarValue::Int32(_)
108+
| ScalarValue::Int64(_)
109+
| ScalarValue::UInt8(_)
110+
| ScalarValue::UInt16(_)
111+
| ScalarValue::UInt32(_)
112+
| ScalarValue::UInt64(_)
113+
| ScalarValue::Decimal32(_, _, _)
114+
| ScalarValue::Decimal64(_, _, _)
115+
| ScalarValue::Decimal128(_, _, _)
116+
| ScalarValue::Decimal256(_, _, _) => Some(false),
117+
118+
other => {
119+
return exec_err!(
120+
"Unsupported data type {other:?} for function {}",
121+
self.name()
122+
);
123+
}
124+
};
125+
126+
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)))
114127
}
115-
};
116-
Ok(ColumnarValue::Array(arr))
128+
ColumnarValue::Array(array) => {
129+
// NOTE: BooleanArray::from_unary preserves nulls.
130+
let arr: ArrayRef = match array.data_type() {
131+
Null => Arc::new(BooleanArray::new_null(array.len())) as ArrayRef,
132+
133+
Float64 => Arc::new(BooleanArray::from_unary(
134+
array.as_primitive::<Float64Type>(),
135+
f64::is_nan,
136+
)) as ArrayRef,
137+
Float32 => Arc::new(BooleanArray::from_unary(
138+
array.as_primitive::<Float32Type>(),
139+
f32::is_nan,
140+
)) as ArrayRef,
141+
Float16 => Arc::new(BooleanArray::from_unary(
142+
array.as_primitive::<Float16Type>(),
143+
|x| x.is_nan(),
144+
)) as ArrayRef,
145+
146+
// Non-float numeric arrays are never NaN
147+
Decimal32(_, _) => Arc::new(BooleanArray::from_unary(
148+
array.as_primitive::<Decimal32Type>(),
149+
|_| false,
150+
)) as ArrayRef,
151+
Decimal64(_, _) => Arc::new(BooleanArray::from_unary(
152+
array.as_primitive::<Decimal64Type>(),
153+
|_| false,
154+
)) as ArrayRef,
155+
Decimal128(_, _) => Arc::new(BooleanArray::from_unary(
156+
array.as_primitive::<Decimal128Type>(),
157+
|_| false,
158+
)) as ArrayRef,
159+
Decimal256(_, _) => Arc::new(BooleanArray::from_unary(
160+
array.as_primitive::<Decimal256Type>(),
161+
|_| false,
162+
)) as ArrayRef,
163+
164+
Int8 => Arc::new(BooleanArray::from_unary(
165+
array.as_primitive::<Int8Type>(),
166+
|_| false,
167+
)) as ArrayRef,
168+
Int16 => Arc::new(BooleanArray::from_unary(
169+
array.as_primitive::<Int16Type>(),
170+
|_| false,
171+
)) as ArrayRef,
172+
Int32 => Arc::new(BooleanArray::from_unary(
173+
array.as_primitive::<Int32Type>(),
174+
|_| false,
175+
)) as ArrayRef,
176+
Int64 => Arc::new(BooleanArray::from_unary(
177+
array.as_primitive::<Int64Type>(),
178+
|_| false,
179+
)) as ArrayRef,
180+
UInt8 => Arc::new(BooleanArray::from_unary(
181+
array.as_primitive::<UInt8Type>(),
182+
|_| false,
183+
)) as ArrayRef,
184+
UInt16 => Arc::new(BooleanArray::from_unary(
185+
array.as_primitive::<UInt16Type>(),
186+
|_| false,
187+
)) as ArrayRef,
188+
UInt32 => Arc::new(BooleanArray::from_unary(
189+
array.as_primitive::<UInt32Type>(),
190+
|_| false,
191+
)) as ArrayRef,
192+
UInt64 => Arc::new(BooleanArray::from_unary(
193+
array.as_primitive::<UInt64Type>(),
194+
|_| false,
195+
)) as ArrayRef,
196+
197+
other => {
198+
return exec_err!(
199+
"Unsupported data type {other:?} for function {}",
200+
self.name()
201+
);
202+
}
203+
};
204+
205+
Ok(ColumnarValue::Array(arr))
206+
}
207+
}
117208
}
118209

119210
fn documentation(&self) -> Option<&Documentation> {

0 commit comments

Comments
 (0)