Skip to content

Commit e0cc868

Browse files
committed
fix: preserve duplicate GROUPING SETS rows
1 parent 4010a55 commit e0cc868

3 files changed

Lines changed: 186 additions & 5 deletions

File tree

datafusion/core/tests/sql/aggregates/basic.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,57 @@ async fn count_aggregated_cube() -> Result<()> {
175175
Ok(())
176176
}
177177

178+
#[tokio::test]
179+
async fn duplicate_grouping_sets_are_preserved() -> Result<()> {
180+
let ctx = SessionContext::new();
181+
let schema = Arc::new(Schema::new(vec![
182+
Field::new("deptno", DataType::Int32, false),
183+
Field::new("job", DataType::Utf8, true),
184+
Field::new("sal", DataType::Int32, true),
185+
Field::new("comm", DataType::Int32, true),
186+
]));
187+
let batch = RecordBatch::try_new(
188+
Arc::clone(&schema),
189+
vec![
190+
Arc::new(Int32Array::from(vec![10, 20])),
191+
Arc::new(StringArray::from(vec![Some("CLERK"), Some("MANAGER")])),
192+
Arc::new(Int32Array::from(vec![1300, 3000])),
193+
Arc::new(Int32Array::from(vec![None, None])),
194+
],
195+
)?;
196+
let provider = MemTable::try_new(Arc::clone(&schema), vec![vec![batch]])?;
197+
ctx.register_table("dup_grouping_sets", Arc::new(provider))?;
198+
199+
let results = plan_and_collect(
200+
&ctx,
201+
"
202+
SELECT deptno, job, sal, sum(comm) AS sum_comm,
203+
grouping(deptno) AS deptno_flag,
204+
grouping(job) AS job_flag,
205+
grouping(sal) AS sal_flag
206+
FROM dup_grouping_sets
207+
GROUP BY GROUPING SETS ((deptno, job), (deptno, sal), (deptno, job))
208+
ORDER BY deptno, job, sal, deptno_flag, job_flag, sal_flag
209+
",
210+
)
211+
.await?;
212+
213+
assert_eq!(results.len(), 1);
214+
assert_snapshot!(batches_to_string(&results), @r"
215+
+--------+---------+------+----------+-------------+----------+----------+
216+
| deptno | job | sal | sum_comm | deptno_flag | job_flag | sal_flag |
217+
+--------+---------+------+----------+-------------+----------+----------+
218+
| 10 | CLERK | | | 0 | 0 | 1 |
219+
| 10 | CLERK | | | 0 | 0 | 1 |
220+
| 10 | | 1300 | | 0 | 1 | 0 |
221+
| 20 | MANAGER | | | 0 | 0 | 1 |
222+
| 20 | MANAGER | | | 0 | 0 | 1 |
223+
| 20 | | 3000 | | 0 | 1 | 0 |
224+
+--------+---------+------+----------+-------------+----------+----------+
225+
");
226+
Ok(())
227+
}
228+
178229
async fn run_count_distinct_integers_aggregated_scenario(
179230
partitions: Vec<Vec<(&str, u64)>>,
180231
) -> Result<Vec<RecordBatch>> {

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 113 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use crate::{
3838
use datafusion_common::config::ConfigOptions;
3939
use datafusion_physical_expr::utils::collect_columns;
4040
use parking_lot::Mutex;
41-
use std::collections::HashSet;
41+
use std::collections::{HashMap, HashSet};
4242

4343
use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array};
4444
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
@@ -1937,15 +1937,53 @@ fn evaluate_optional(
19371937
.collect()
19381938
}
19391939

1940-
fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
1940+
fn group_id_array(
1941+
group: &[bool],
1942+
group_ordinal: usize,
1943+
batch: &RecordBatch,
1944+
) -> Result<ArrayRef> {
19411945
if group.len() > 64 {
19421946
return not_impl_err!(
19431947
"Grouping sets with more than 64 columns are not supported"
19441948
);
19451949
}
1950+
let width_bits = if group.len() <= 8 {
1951+
8
1952+
} else if group.len() <= 16 {
1953+
16
1954+
} else if group.len() <= 32 {
1955+
32
1956+
} else {
1957+
64
1958+
};
1959+
let extra_bits = width_bits - group.len();
1960+
if extra_bits == 0 && group_ordinal > 0 {
1961+
return not_impl_err!(
1962+
"Duplicate grouping sets with more than {} grouping columns are not supported",
1963+
width_bits
1964+
);
1965+
}
1966+
if extra_bits < usize::BITS as usize {
1967+
let max_group_ordinal = 1usize << extra_bits;
1968+
if group_ordinal >= max_group_ordinal {
1969+
return not_impl_err!(
1970+
"Duplicate grouping sets exceed the supported grouping id capacity"
1971+
);
1972+
}
1973+
}
19461974
let group_id = group.iter().fold(0u64, |acc, &is_null| {
19471975
(acc << 1) | if is_null { 1 } else { 0 }
19481976
});
1977+
let group_id = if group.len() == 64 {
1978+
if group_ordinal > 0 {
1979+
return not_impl_err!(
1980+
"Duplicate grouping sets with 64 grouping columns are not supported"
1981+
);
1982+
}
1983+
group_id
1984+
} else {
1985+
((group_ordinal as u64) << group.len()) | group_id
1986+
};
19491987
let num_rows = batch.num_rows();
19501988
if group.len() <= 8 {
19511989
Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
@@ -1972,6 +2010,7 @@ pub fn evaluate_group_by(
19722010
group_by: &PhysicalGroupBy,
19732011
batch: &RecordBatch,
19742012
) -> Result<Vec<Vec<ArrayRef>>> {
2013+
let mut group_ordinals: HashMap<Vec<bool>, usize> = HashMap::new();
19752014
let exprs = evaluate_expressions_to_arrays(
19762015
group_by.expr.iter().map(|(expr, _)| expr),
19772016
batch,
@@ -1985,6 +2024,10 @@ pub fn evaluate_group_by(
19852024
.groups
19862025
.iter()
19872026
.map(|group| {
2027+
let group_ordinal = group_ordinals.entry(group.clone()).or_insert(0);
2028+
let current_group_ordinal = *group_ordinal;
2029+
*group_ordinal += 1;
2030+
19882031
let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
19892032
group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
19902033
if *is_null {
@@ -1994,7 +2037,7 @@ pub fn evaluate_group_by(
19942037
}
19952038
}));
19962039
if !group_by.is_single() {
1997-
group_values.push(group_id_array(group, batch)?);
2040+
group_values.push(group_id_array(group, current_group_ordinal, batch)?);
19982041
}
19992042
Ok(group_values)
20002043
})
@@ -2018,8 +2061,8 @@ mod tests {
20182061
use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
20192062

20202063
use arrow::array::{
2021-
DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StructArray,
2022-
UInt32Array, UInt64Array,
2064+
DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray,
2065+
StructArray, UInt32Array, UInt64Array,
20232066
};
20242067
use arrow::compute::{SortOptions, concat_batches};
20252068
use arrow::datatypes::{DataType, Int32Type};
@@ -3478,6 +3521,71 @@ mod tests {
34783521
Ok(())
34793522
}
34803523

3524+
#[tokio::test]
3525+
async fn grouping_sets_preserve_duplicate_groups() -> Result<()> {
3526+
let schema = Arc::new(Schema::new(vec![
3527+
Field::new("deptno", DataType::Int32, false),
3528+
Field::new("job", DataType::Utf8, true),
3529+
Field::new("sal", DataType::Float64, true),
3530+
Field::new("comm", DataType::Float64, true),
3531+
]));
3532+
3533+
let input = TestMemoryExec::try_new_exec(
3534+
&[vec![RecordBatch::try_new(
3535+
Arc::clone(&schema),
3536+
vec![
3537+
Arc::new(Int32Array::from(vec![10, 20])),
3538+
Arc::new(StringArray::from(vec![Some("CLERK"), Some("MANAGER")])),
3539+
Arc::new(Float64Array::from(vec![1300.0, 3000.0])),
3540+
Arc::new(Float64Array::from(vec![None, None])),
3541+
],
3542+
)?]],
3543+
Arc::clone(&schema),
3544+
None,
3545+
)?;
3546+
3547+
let group_by = PhysicalGroupBy::new(
3548+
vec![
3549+
(col("deptno", &schema)?, "deptno".to_string()),
3550+
(col("job", &schema)?, "job".to_string()),
3551+
(col("sal", &schema)?, "sal".to_string()),
3552+
],
3553+
vec![
3554+
(lit(ScalarValue::Int32(None)), "deptno".to_string()),
3555+
(lit(ScalarValue::Utf8(None)), "job".to_string()),
3556+
(lit(ScalarValue::Float64(None)), "sal".to_string()),
3557+
],
3558+
vec![
3559+
vec![false, false, true],
3560+
vec![false, true, false],
3561+
vec![false, false, true],
3562+
],
3563+
true,
3564+
);
3565+
3566+
let aggr_exprs: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
3567+
AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
3568+
.schema(Arc::clone(&schema))
3569+
.alias("COUNT(1)")
3570+
.build()?,
3571+
)];
3572+
3573+
let aggregate_exec = Arc::new(AggregateExec::try_new(
3574+
AggregateMode::Single,
3575+
group_by,
3576+
aggr_exprs,
3577+
vec![None],
3578+
Arc::clone(&input) as Arc<dyn ExecutionPlan>,
3579+
Arc::clone(&schema),
3580+
)?);
3581+
3582+
let output =
3583+
collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
3584+
let batch = concat_batches(&output[0].schema(), &output)?;
3585+
assert_eq!(batch.num_rows(), 6);
3586+
Ok(())
3587+
}
3588+
34813589
// test for https://github.com/apache/datafusion/issues/13949
34823590
async fn run_test_with_spill_pool_if_necessary(
34833591
pool_size: usize,

datafusion/sqllogictest/test_files/group_by.slt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5203,6 +5203,28 @@ NULL NULL 1
52035203
statement ok
52045204
drop table t;
52055205

5206+
# regression: duplicate grouping sets must not be collapsed into one
5207+
statement ok
5208+
create table duplicate_grouping_sets(deptno int, job varchar, sal int, comm int) as values
5209+
(10, 'CLERK', 1300, null),
5210+
(20, 'MANAGER', 3000, null);
5211+
5212+
query IT?I?I?III
5213+
select deptno, job, sal, sum(comm), grouping(deptno), grouping(job), grouping(sal)
5214+
from duplicate_grouping_sets
5215+
group by grouping sets ((deptno, job), (deptno, sal), (deptno, job))
5216+
order by deptno, job, sal, grouping(deptno), grouping(job), grouping(sal);
5217+
----
5218+
10 CLERK NULL NULL 0 0 1
5219+
10 CLERK NULL NULL 0 0 1
5220+
10 NULL 1300 NULL 0 1 0
5221+
20 MANAGER NULL NULL 0 0 1
5222+
20 MANAGER NULL NULL 0 0 1
5223+
20 NULL 3000 NULL 0 1 0
5224+
5225+
statement ok
5226+
drop table duplicate_grouping_sets;
5227+
52065228
# test multi group by for binary type without nulls
52075229
statement ok
52085230
create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, 0xb), (3, 0xb), (3, 0xb);

0 commit comments

Comments
 (0)