Skip to content

Commit f5a2ac3

Browse files
authored
fix: percentile_cont interpolation causes NaN for f16 input (#20208)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #18945 ## Rationale for this change percentile_cont interpolation for Float16 could overflow f16 intermediates (e.g. when scaling the fractional component), producing inf/NaN and incorrect results. This PR makes interpolation numerically safe for f16. <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? • Perform percentile interpolation in f64 and cast back to the input float type (f16/f32/f64) to avoid f16 overflow. • Add a regression unit test covering Float16 interpolation near the maximum finite value. <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? Yes <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? Yes. percentile_cont on Float16 inputs no longer returns NaN due to interpolation overflow and will produce correct finite results for valid finite f16 data <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 4ad5c3d commit f5a2ac3

4 files changed

Lines changed: 72 additions & 25 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/functions-aggregate/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ datafusion-physical-expr = { workspace = true }
5353
datafusion-physical-expr-common = { workspace = true }
5454
half = { workspace = true }
5555
log = { workspace = true }
56+
num-traits = { workspace = true }
5657
paste = { workspace = true }
5758

5859
[dev-dependencies]

datafusion/functions-aggregate/src/percentile_cont.rs

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ use arrow::array::{
2626
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
2727
use arrow::{
2828
array::{Array, ArrayRef, AsArray},
29-
datatypes::{
30-
ArrowNativeType, DataType, Field, FieldRef, Float16Type, Float32Type, Float64Type,
31-
},
29+
datatypes::{DataType, Field, FieldRef, Float16Type, Float32Type, Float64Type},
3230
};
3331

32+
use num_traits::AsPrimitive;
33+
3434
use arrow::array::ArrowNativeTypeOp;
3535
use datafusion_common::internal_err;
3636
use datafusion_common::types::{NativeType, logical_float64};
@@ -68,7 +68,10 @@ use crate::utils::validate_percentile_expr;
6868
/// The interpolation formula: `lower + (upper - lower) * fraction`
6969
/// is computed as: `lower + ((upper - lower) * (fraction * PRECISION)) / PRECISION`
7070
/// to avoid floating-point operations on integer types while maintaining precision.
71-
const INTERPOLATION_PRECISION: usize = 1_000_000;
71+
///
72+
/// The interpolation arithmetic is performed in f64 and then cast back to the
73+
/// native type to avoid overflowing Float16 intermediates.
74+
const INTERPOLATION_PRECISION: f64 = 1_000_000.0;
7275

7376
create_func!(PercentileCont, percentile_cont_udaf);
7477

@@ -389,7 +392,12 @@ impl<T: ArrowNumericType + Debug> PercentileContAccumulator<T> {
389392
}
390393
}
391394

392-
impl<T: ArrowNumericType + Debug> Accumulator for PercentileContAccumulator<T> {
395+
impl<T> Accumulator for PercentileContAccumulator<T>
396+
where
397+
T: ArrowNumericType + Debug,
398+
T::Native: Copy + AsPrimitive<f64>,
399+
f64: AsPrimitive<T::Native>,
400+
{
393401
fn state(&mut self) -> Result<Vec<ScalarValue>> {
394402
// Convert `all_values` to `ListArray` and return a single List ScalarValue
395403

@@ -493,8 +501,11 @@ impl<T: ArrowNumericType + Send> PercentileContGroupsAccumulator<T> {
493501
}
494502
}
495503

496-
impl<T: ArrowNumericType + Send> GroupsAccumulator
497-
for PercentileContGroupsAccumulator<T>
504+
impl<T> GroupsAccumulator for PercentileContGroupsAccumulator<T>
505+
where
506+
T: ArrowNumericType + Send,
507+
T::Native: Copy + AsPrimitive<f64>,
508+
f64: AsPrimitive<T::Native>,
498509
{
499510
fn update_batch(
500511
&mut self,
@@ -673,7 +684,12 @@ impl<T: ArrowNumericType + Debug> DistinctPercentileContAccumulator<T> {
673684
}
674685
}
675686

676-
impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumulator<T> {
687+
impl<T> Accumulator for DistinctPercentileContAccumulator<T>
688+
where
689+
T: ArrowNumericType + Debug,
690+
T::Native: Copy + AsPrimitive<f64>,
691+
f64: AsPrimitive<T::Native>,
692+
{
677693
fn state(&mut self) -> Result<Vec<ScalarValue>> {
678694
self.distinct_values.state()
679695
}
@@ -728,7 +744,11 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctPercentileContAccumula
728744
fn calculate_percentile<T: ArrowNumericType>(
729745
values: &mut [T::Native],
730746
percentile: f64,
731-
) -> Option<T::Native> {
747+
) -> Option<T::Native>
748+
where
749+
T::Native: Copy + AsPrimitive<f64>,
750+
f64: AsPrimitive<T::Native>,
751+
{
732752
let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
733753

734754
let len = values.len();
@@ -772,22 +792,47 @@ fn calculate_percentile<T: ArrowNumericType>(
772792
let (_, upper_value, _) = values.select_nth_unstable_by(upper_index, cmp);
773793
let upper_value = *upper_value;
774794

775-
// Linear interpolation using wrapping arithmetic
776-
// We use wrapping operations here (matching the approach in median.rs) because:
777-
// 1. Both values come from the input data, so diff is bounded by the value range
778-
// 2. fraction is between 0 and 1, and INTERPOLATION_PRECISION is small enough
779-
// to prevent overflow when combined with typical numeric ranges
780-
// 3. The result is guaranteed to be between lower_value and upper_value
781-
// 4. For floating-point types, wrapping ops behave the same as standard ops
795+
// Linear interpolation.
796+
// We compute a quantized interpolation weight using `INTERPOLATION_PRECISION` because:
797+
// 1. Both values come from the input data, so (upper - lower) is bounded by the value range
798+
// 2. fraction is between 0 and 1; quantizing it provides stable, predictable results
799+
// 3. The result is guaranteed to be between lower_value and upper_value (modulo cast rounding)
800+
// 4. Arithmetic is performed in f64 and cast back to avoid overflowing Float16 intermediates
782801
let fraction = index - (lower_index as f64);
783-
let diff = upper_value.sub_wrapping(lower_value);
784-
let interpolated = lower_value.add_wrapping(
785-
diff.mul_wrapping(T::Native::usize_as(
786-
(fraction * INTERPOLATION_PRECISION as f64) as usize,
787-
))
788-
.div_wrapping(T::Native::usize_as(INTERPOLATION_PRECISION)),
789-
);
790-
Some(interpolated)
802+
let scaled = (fraction * INTERPOLATION_PRECISION) as usize;
803+
let weight = scaled as f64 / INTERPOLATION_PRECISION;
804+
805+
let lower_f: f64 = lower_value.as_();
806+
let upper_f: f64 = upper_value.as_();
807+
let interpolated_f = lower_f + (upper_f - lower_f) * weight;
808+
Some(interpolated_f.as_())
791809
}
792810
}
793811
}
812+
813+
#[cfg(test)]
814+
mod tests {
815+
use super::calculate_percentile;
816+
use half::f16;
817+
818+
#[test]
819+
fn f16_interpolation_does_not_overflow_to_nan() {
820+
// Regression test for https://github.com/apache/datafusion/issues/18945
821+
// Interpolating between 0 and the max finite f16 value previously overflowed
822+
// intermediate f16 computations and produced NaN.
823+
let mut values = vec![f16::from_f32(0.0), f16::from_f32(65504.0)];
824+
let result =
825+
calculate_percentile::<arrow::datatypes::Float16Type>(&mut values, 0.5)
826+
.expect("non-empty input");
827+
let result_f = result.to_f32();
828+
assert!(
829+
!result_f.is_nan(),
830+
"expected non-NaN result, got {result_f}"
831+
);
832+
// 0.5 percentile should be close to midpoint
833+
assert!(
834+
(result_f - 32752.0).abs() < 1.0,
835+
"unexpected result {result_f}"
836+
);
837+
}
838+
}

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1331,7 +1331,7 @@ select
13311331
arrow_typeof(percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16')))
13321332
from median_table;
13331333
----
1334-
NaN Float16
1334+
2.75 Float16
13351335

13361336
query RT
13371337
select

0 commit comments

Comments
 (0)