Skip to content

Commit cad3865

Browse files
authored
fix: correct weight handling in approx_percentile_cont_with_weight (#19941)
The approx_percentile_cont_with_weight function was producing incorrect results due to wrong weight handling in the TDigest implementation.

Root cause: In TDigest::new_with_centroid(), the count field was hardcoded to 1 regardless of the actual centroid weight, while the weight was correctly used in the sum calculation. This mismatch caused incorrect percentile calculations since estimate_quantile() uses count to compute the rank.

Changes:
- Changed TDigest::count from u64 to f64 to properly support fractional weights (consistent with ClickHouse's TDigest implementation)
- Fixed new_with_centroid() to use centroid.weight for count
- Updated state_fields() in approx_percentile_cont and approx_median to use Float64 for the count field
- Added early return in merge_digests() when all centroids have zero weight to prevent panic
- Updated test expectations to reflect correct weighted percentile behavior

## Which issue does this PR close?

- Closes #19940

## Rationale for this change

The `approx_percentile_cont_with_weight` function produces incorrect weighted percentile results. The bug is in the TDigest implementation where `new_with_centroid()` sets `count: 1` regardless of the actual centroid weight, while the weight is used elsewhere in centroid merging. This mismatch corrupts the percentile calculation.

## What changes are included in this PR?
- Changed `TDigest::count` from `u64` to `f64` to properly support fractional weights (consistent with [ClickHouse's TDigest implementation](https://github.com/ClickHouse/ClickHouse/blob/927af1255adb37ace1b95cc3ec4316553b4cb4b4/src/AggregateFunctions/QuantileTDigest.h#L71-L87))
- Fixed `new_with_centroid()` to use `centroid.weight` for count
- Updated `state_fields()` in `approx_percentile_cont` and `approx_median` to use `Float64` for the count field
- Added early return in `merge_digests()` when all centroids have zero weight to prevent panic
- Updated test expectations to reflect correct weighted percentile behavior

## Are these changes tested?

Yes.
- All existing unit tests in tdigest.rs pass (7 tests)
- All SQL logic tests for aggregate functions pass
- Manual testing confirms correct behavior with various weight distributions (equal weights, heavy low/high values, linear weights, fractional weights)

## Are there any user-facing changes?

Yes, this is a breaking change:
1. Result changes: approx_percentile_cont_with_weight now returns correct weighted percentiles. Queries relying on the previous (incorrect) behavior will see different results.
2. Serialized state format change: The TDigest state field count changes from UInt64 to Float64. Any existing serialized/checkpointed TDigest state will be incompatible and cannot be restored.
3. Edge case behavior change: When all weights are zero, the function now returns NULL instead of the previous undefined behavior.
1 parent f0de02f commit cad3865

4 files changed

Lines changed: 40 additions & 45 deletions

File tree

datafusion/functions-aggregate-common/src/tdigest.rs

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,6 @@ macro_rules! cast_scalar_f64 {
4949
};
5050
}
5151

52-
// Cast a non-null [`ScalarValue::UInt64`] to an [`u64`], or
53-
// panic.
54-
macro_rules! cast_scalar_u64 {
55-
($value:expr ) => {
56-
match &$value {
57-
ScalarValue::UInt64(Some(v)) => *v,
58-
v => panic!("invalid type {}", v),
59-
}
60-
};
61-
}
62-
6352
/// Centroid implementation to the cluster mentioned in the paper.
6453
#[derive(Debug, PartialEq, Clone)]
6554
pub struct Centroid {
@@ -110,7 +99,7 @@ pub struct TDigest {
11099
centroids: Vec<Centroid>,
111100
max_size: usize,
112101
sum: f64,
113-
count: u64,
102+
count: f64,
114103
max: f64,
115104
min: f64,
116105
}
@@ -120,8 +109,8 @@ impl TDigest {
120109
TDigest {
121110
centroids: Vec::new(),
122111
max_size,
123-
sum: 0_f64,
124-
count: 0,
112+
sum: 0.0,
113+
count: 0.0,
125114
max: f64::NAN,
126115
min: f64::NAN,
127116
}
@@ -133,14 +122,14 @@ impl TDigest {
133122
centroids: vec![centroid.clone()],
134123
max_size,
135124
sum: centroid.mean * centroid.weight,
136-
count: 1,
125+
count: centroid.weight,
137126
max: centroid.mean,
138127
min: centroid.mean,
139128
}
140129
}
141130

142131
#[inline]
143-
pub fn count(&self) -> u64 {
132+
pub fn count(&self) -> f64 {
144133
self.count
145134
}
146135

@@ -170,8 +159,8 @@ impl Default for TDigest {
170159
TDigest {
171160
centroids: Vec::new(),
172161
max_size: 100,
173-
sum: 0_f64,
174-
count: 0,
162+
sum: 0.0,
163+
count: 0.0,
175164
max: f64::NAN,
176165
min: f64::NAN,
177166
}
@@ -216,12 +205,12 @@ impl TDigest {
216205
}
217206

218207
let mut result = TDigest::new(self.max_size());
219-
result.count = self.count() + sorted_values.len() as u64;
208+
result.count = self.count() + sorted_values.len() as f64;
220209

221210
let maybe_min = *sorted_values.first().unwrap();
222211
let maybe_max = *sorted_values.last().unwrap();
223212

224-
if self.count() > 0 {
213+
if self.count() > 0.0 {
225214
result.min = self.min.min(maybe_min);
226215
result.max = self.max.max(maybe_max);
227216
} else {
@@ -233,7 +222,7 @@ impl TDigest {
233222

234223
let mut k_limit: u64 = 1;
235224
let mut q_limit_times_count =
236-
Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
225+
Self::k_to_q(k_limit, self.max_size) * result.count();
237226
k_limit += 1;
238227

239228
let mut iter_centroids = self.centroids.iter().peekable();
@@ -281,7 +270,7 @@ impl TDigest {
281270

282271
compressed.push(curr.clone());
283272
q_limit_times_count =
284-
Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
273+
Self::k_to_q(k_limit, self.max_size) * result.count();
285274
k_limit += 1;
286275
curr = next;
287276
}
@@ -353,7 +342,7 @@ impl TDigest {
353342
let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
354343
let mut starts: Vec<usize> = Vec::with_capacity(digests.len());
355344

356-
let mut count = 0;
345+
let mut count = 0.0;
357346
let mut min = f64::INFINITY;
358347
let mut max = f64::NEG_INFINITY;
359348

@@ -362,7 +351,7 @@ impl TDigest {
362351
starts.push(start);
363352

364353
let curr_count = digest.count();
365-
if curr_count > 0 {
354+
if curr_count > 0.0 {
366355
min = min.min(digest.min);
367356
max = max.max(digest.max);
368357
count += curr_count;
@@ -373,6 +362,11 @@ impl TDigest {
373362
}
374363
}
375364

365+
// If no centroids were added (all digests had zero count), return default
366+
if centroids.is_empty() {
367+
return TDigest::default();
368+
}
369+
376370
let mut digests_per_block: usize = 1;
377371
while digests_per_block < starts.len() {
378372
for i in (0..starts.len()).step_by(digests_per_block * 2) {
@@ -397,7 +391,7 @@ impl TDigest {
397391
let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);
398392

399393
let mut k_limit = 1;
400-
let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
394+
let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count;
401395

402396
let mut iter_centroids = centroids.iter_mut();
403397
let mut curr = iter_centroids.next().unwrap();
@@ -416,7 +410,7 @@ impl TDigest {
416410
sums_to_merge = 0_f64;
417411
weights_to_merge = 0_f64;
418412
compressed.push(curr.clone());
419-
q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
413+
q_limit_times_count = Self::k_to_q(k_limit, max_size) * count;
420414
k_limit += 1;
421415
curr = centroid;
422416
}
@@ -440,7 +434,7 @@ impl TDigest {
440434
return 0.0;
441435
}
442436

443-
let rank = q * self.count as f64;
437+
let rank = q * self.count;
444438

445439
let mut pos: usize;
446440
let mut t;
@@ -450,7 +444,7 @@ impl TDigest {
450444
}
451445

452446
pos = 0;
453-
t = self.count as f64;
447+
t = self.count;
454448

455449
for (k, centroid) in self.centroids.iter().enumerate().rev() {
456450
t -= centroid.weight();
@@ -563,7 +557,7 @@ impl TDigest {
563557
vec![
564558
ScalarValue::UInt64(Some(self.max_size as u64)),
565559
ScalarValue::Float64(Some(self.sum)),
566-
ScalarValue::UInt64(Some(self.count)),
560+
ScalarValue::Float64(Some(self.count)),
567561
ScalarValue::Float64(Some(self.max)),
568562
ScalarValue::Float64(Some(self.min)),
569563
ScalarValue::List(arr),
@@ -611,7 +605,7 @@ impl TDigest {
611605
Self {
612606
max_size,
613607
sum: cast_scalar_f64!(state[1]),
614-
count: cast_scalar_u64!(&state[2]),
608+
count: cast_scalar_f64!(state[2]),
615609
max,
616610
min,
617611
centroids,

datafusion/functions-aggregate/src/approx_median.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ impl AggregateUDFImpl for ApproxMedian {
110110
Ok(vec![
111111
Field::new(format_state_name(args.name, "max_size"), UInt64, false),
112112
Field::new(format_state_name(args.name, "sum"), Float64, false),
113-
Field::new(format_state_name(args.name, "count"), UInt64, false),
113+
Field::new(format_state_name(args.name, "count"), Float64, false),
114114
Field::new(format_state_name(args.name, "max"), Float64, false),
115115
Field::new(format_state_name(args.name, "min"), Float64, false),
116116
Field::new_list(

datafusion/functions-aggregate/src/approx_percentile_cont.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ impl AggregateUDFImpl for ApproxPercentileCont {
259259
),
260260
Field::new(
261261
format_state_name(args.name, "count"),
262-
DataType::UInt64,
262+
DataType::Float64,
263263
false,
264264
),
265265
Field::new(
@@ -436,7 +436,7 @@ impl Accumulator for ApproxPercentileAccumulator {
436436
}
437437

438438
fn evaluate(&mut self) -> Result<ScalarValue> {
439-
if self.digest.count() == 0 {
439+
if self.digest.count() == 0.0 {
440440
return ScalarValue::try_from(self.return_type.clone());
441441
}
442442
let q = self.digest.estimate_quantile(self.percentile);
@@ -513,8 +513,8 @@ mod tests {
513513
ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);
514514

515515
accumulator.merge_digests(&[t1]);
516-
assert_eq!(accumulator.digest.count(), 50_000);
516+
assert_eq!(accumulator.digest.count(), 50_000.0);
517517
accumulator.merge_digests(&[t2]);
518-
assert_eq!(accumulator.digest.count(), 100_000);
518+
assert_eq!(accumulator.digest.count(), 100_000.0);
519519
}
520520
}

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,11 +2029,12 @@ statement ok
20292029
INSERT INTO t1 VALUES (TRUE);
20302030

20312031
# ISSUE: https://github.com/apache/datafusion/issues/12716
2032-
# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' and returns 'inf'
2032+
# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN'
2033+
# With weight=0, the data point does not contribute, so result is NULL
20332034
query R
20342035
SELECT approx_percentile_cont_with_weight(0, 0) WITHIN GROUP (ORDER BY 'NaN'::DOUBLE) FROM t1 WHERE t1.v1;
20352036
----
2036-
Infinity
2037+
NULL
20372038

20382039
statement ok
20392040
DROP TABLE t1;
@@ -2352,21 +2353,21 @@ e 115
23522353
query TI
23532354
SELECT c1, approx_percentile_cont_with_weight(c2, 0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
23542355
----
2355-
a 74
2356+
a 65
23562357
b 68
2357-
c 123
2358-
d 124
2359-
e 115
2358+
c 122
2359+
d 123
2360+
e 110
23602361

23612362
# approx_percentile_cont_with_weight with centroids
23622363
query TI
23632364
SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
23642365
----
2365-
a 74
2366+
a 65
23662367
b 68
2367-
c 123
2368-
d 124
2369-
e 115
2368+
c 122
2369+
d 123
2370+
e 110
23702371

23712372
# csv_query_sum_crossjoin
23722373
query TTI

0 commit comments

Comments
 (0)