Skip to content

Commit 933657e

Browse files
authored
feat: Support sliding window queries for MedianAccumulator by implementing retract_batch (apache#19278)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes apache#123` indicates that this PR will close issue apache#123. --> - Closes apache#7664 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Added tests ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> Computing the median() window is now supported instead of throwing an error <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent fc88240 commit 933657e

3 files changed

Lines changed: 146 additions & 27 deletions

File tree

datafusion/core/tests/dataframe/mod.rs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,26 +1102,26 @@ async fn window_using_aggregates() -> Result<()> {
11021102
| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |
11031103
+-------------+----------+-----------------+---------------+--------+-----+------+----+------+
11041104
| | | | | | | | 1 | -85 |
1105-
| -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 |
1106-
| -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 |
1107-
| -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 |
1108-
| -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 |
1109-
| -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 |
1110-
| -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 |
1105+
| -85 | -101 | 14 | -12 | -12 | 83 | -101 | 4 | -54 |
1106+
| -85 | -101 | 17 | -25 | -25 | 83 | -101 | 5 | -31 |
1107+
| -85 | -12 | 10 | -32 | -34 | 83 | -85 | 3 | 13 |
1108+
| -85 | -25 | 3 | -56 | -56 | -25 | -85 | 1 | -5 |
1109+
| -85 | -31 | 18 | -29 | -28 | 83 | -101 | 5 | 36 |
1110+
| -85 | -38 | 16 | -25 | -25 | 83 | -101 | 4 | 65 |
11111111
| -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 |
1112-
| -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 |
1113-
| -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 |
1114-
| -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 |
1115-
| -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 |
1116-
| -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 |
1112+
| -85 | -48 | 6 | -35 | -36 | 83 | -85 | 2 | -43 |
1113+
| -85 | -5 | 4 | -37 | -40 | -5 | -85 | 1 | 83 |
1114+
| -85 | -54 | 15 | -17 | -18 | 83 | -101 | 4 | -38 |
1115+
| -85 | -56 | 2 | -70 | 57 | -56 | -85 | 1 | -25 |
1116+
| -85 | -72 | 9 | -43 | -43 | 83 | -85 | 3 | -12 |
11171117
| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |
1118-
| -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 |
1119-
| -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 |
1120-
| -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 |
1121-
| -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 |
1122-
| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |
1123-
| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |
1124-
| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |
1118+
| -85 | 13 | 11 | -17 | -18 | 83 | -85 | 3 | 14 |
1119+
| -85 | 13 | 11 | -25 | -25 | 83 | -85 | 3 | 13 |
1120+
| -85 | 14 | 12 | -12 | -12 | 83 | -85 | 3 | 17 |
1121+
| -85 | 17 | 13 | -11 | -8 | 83 | -85 | 4 | -101 |
1122+
| -85 | 45 | 8 | -34 | -34 | 83 | -85 | 3 | -72 |
1123+
| -85 | 65 | 17 | -17 | -18 | 83 | -101 | 5 | -101 |
1124+
| -85 | 83 | 5 | -25 | -25 | 83 | -85 | 2 | -48 |
11251125
+-------------+----------+-----------------+---------------+--------+-----+------+----+------+
11261126
"###
11271127
);

datafusion/functions-aggregate/src/median.rs

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumu
5353
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
5454
use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
5555
use datafusion_macros::user_doc;
56+
use std::collections::HashMap;
5657

5758
make_udaf_expr_and_func!(
5859
Median,
@@ -289,14 +290,51 @@ impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
289290
}
290291

291292
fn evaluate(&mut self) -> Result<ScalarValue> {
292-
let d = std::mem::take(&mut self.all_values);
293-
let median = calculate_median::<T>(d);
293+
let median = calculate_median::<T>(&mut self.all_values);
294294
ScalarValue::new_primitive::<T>(median, &self.data_type)
295295
}
296296

297297
/// Returns the size of this accumulator in bytes: the size of the struct itself
/// plus the allocated capacity of `all_values` (in elements) multiplied by the
/// size of one `T::Native` value.
fn size(&self) -> usize {
298298
size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
299299
}
300+
301+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
302+
let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
303+
304+
let arr = &values[0];
305+
for i in 0..arr.len() {
306+
let v = ScalarValue::try_from_array(arr, i)?;
307+
if !v.is_null() {
308+
*to_remove.entry(v).or_default() += 1;
309+
}
310+
}
311+
312+
let mut i = 0;
313+
while i < self.all_values.len() {
314+
let k = ScalarValue::new_primitive::<T>(
315+
Some(self.all_values[i]),
316+
&self.data_type,
317+
)?;
318+
if let Some(count) = to_remove.get_mut(&k)
319+
&& *count > 0
320+
{
321+
self.all_values.swap_remove(i);
322+
*count -= 1;
323+
if *count == 0 {
324+
to_remove.remove(&k);
325+
if to_remove.is_empty() {
326+
break;
327+
}
328+
}
329+
}
330+
i += 1;
331+
}
332+
Ok(())
333+
}
334+
335+
/// Always returns `true`, signalling that this accumulator provides `retract_batch`.
fn supports_retract_batch(&self) -> bool {
336+
true
337+
}
300338
}
301339

302340
/// The median groups accumulator accumulates the raw input values
@@ -443,8 +481,8 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T
443481
// Calculate median for each group
444482
let mut evaluate_result_builder =
445483
PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
446-
for values in emit_group_values {
447-
let median = calculate_median::<T>(values);
484+
for mut values in emit_group_values {
485+
let median = calculate_median::<T>(&mut values);
448486
evaluate_result_builder.append_option(median);
449487
}
450488

@@ -528,11 +566,11 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctMedianAccumulator<T> {
528566
}
529567

530568
fn evaluate(&mut self) -> Result<ScalarValue> {
531-
let d = std::mem::take(&mut self.distinct_values.values)
569+
let mut d = std::mem::take(&mut self.distinct_values.values)
532570
.into_iter()
533571
.map(|v| v.0)
534572
.collect::<Vec<_>>();
535-
let median = calculate_median::<T>(d);
573+
let median = calculate_median::<T>(&mut d);
536574
ScalarValue::new_primitive::<T>(median, &self.data_type)
537575
}
538576

@@ -556,9 +594,7 @@ where
556594
.unwrap()
557595
}
558596

559-
fn calculate_median<T: ArrowNumericType>(
560-
mut values: Vec<T::Native>,
561-
) -> Option<T::Native> {
597+
fn calculate_median<T: ArrowNumericType>(values: &mut [T::Native]) -> Option<T::Native> {
562598
let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
563599

564600
let len = values.len();

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,89 @@ SELECT approx_median(col_f64_nan) FROM median_table
991991
----
992992
NaN
993993

994+
# median_sliding_window
995+
statement ok
996+
CREATE TABLE median_window_test (
997+
timestamp INT,
998+
tags VARCHAR,
999+
value DOUBLE
1000+
);
1001+
1002+
statement ok
1003+
INSERT INTO median_window_test (timestamp, tags, value) VALUES
1004+
(1, 'tag1', 10.0),
1005+
(2, 'tag1', 20.0),
1006+
(3, 'tag1', 30.0),
1007+
(4, 'tag1', 40.0),
1008+
(5, 'tag1', 50.0),
1009+
(1, 'tag2', 60.0),
1010+
(2, 'tag2', 70.0),
1011+
(3, 'tag2', 80.0),
1012+
(4, 'tag2', 90.0),
1013+
(5, 'tag2', 100.0);
1014+
1015+
query ITRR
1016+
SELECT
1017+
timestamp,
1018+
tags,
1019+
value,
1020+
median(value) OVER (
1021+
PARTITION BY tags
1022+
ORDER BY timestamp
1023+
ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
1024+
) AS value_median_3
1025+
FROM median_window_test
1026+
ORDER BY tags, timestamp;
1027+
----
1028+
1 tag1 10 15
1029+
2 tag1 20 20
1030+
3 tag1 30 30
1031+
4 tag1 40 40
1032+
5 tag1 50 45
1033+
1 tag2 60 65
1034+
2 tag2 70 70
1035+
3 tag2 80 80
1036+
4 tag2 90 90
1037+
5 tag2 100 95
1038+
1039+
# median_non_sliding_window
1040+
query ITRRRR
1041+
SELECT
1042+
timestamp,
1043+
tags,
1044+
value,
1045+
median(value) OVER (
1046+
PARTITION BY tags
1047+
ORDER BY timestamp
1048+
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
1049+
) AS value_median_unbounded_preceding,
1050+
median(value) OVER (
1051+
PARTITION BY tags
1052+
ORDER BY timestamp
1053+
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
1054+
) AS value_median_unbounded_both,
1055+
median(value) OVER (
1056+
PARTITION BY tags
1057+
ORDER BY timestamp
1058+
ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING
1059+
) AS value_median_unbounded_following
1060+
FROM median_window_test
1061+
ORDER BY tags, timestamp;
1062+
----
1063+
1 tag1 10 10 30 30
1064+
2 tag1 20 15 30 35
1065+
3 tag1 30 20 30 40
1066+
4 tag1 40 25 30 45
1067+
5 tag1 50 30 30 50
1068+
1 tag2 60 60 80 80
1069+
2 tag2 70 65 80 85
1070+
3 tag2 80 70 80 90
1071+
4 tag2 90 75 80 95
1072+
5 tag2 100 80 80 100
1073+
1074+
statement ok
1075+
DROP TABLE median_window_test;
1076+
9941077
query RT
9951078
select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median(arrow_cast(col_f32, 'Float16'))) from median_table;
9961079
----

0 commit comments

Comments
 (0)