Skip to content

Commit 59fa290

Browse files
author
yash
committed
perf: default multi COUNT(DISTINCT) logical optimizer rewrite
Add MultiDistinctCountRewrite in datafusion-optimizer and register it in Optimizer::new() after SingleDistinctToGroupBy. The rule rewrites two or more simple COUNT(DISTINCT) aggregates over different arguments into a join of two-phase aggregates, and filters distinct_arg IS NOT NULL on each branch to keep NULL semantics correct. ✅ Unit tests in datafusion-optimizer; ✅ SQL integration test (with NULLs) in core_integration.
1 parent 5e54b89 commit 59fa290

5 files changed

Lines changed: 457 additions & 0 deletions

File tree

datafusion/core/tests/sql/aggregates/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,3 +1024,4 @@ pub fn split_fuzz_timestamp_data_into_batches(
10241024

10251025
pub mod basic;
10261026
pub mod dict_nulls;
1027+
pub mod multi_distinct_count_rewrite;
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! End-to-end SQL tests for the multi-`COUNT(DISTINCT)` logical optimizer rewrite.
19+
20+
use super::*;
21+
use arrow::array::{Int32Array, StringArray};
22+
use datafusion::common::test_util::batches_to_sort_string;
23+
use datafusion_catalog::MemTable;
24+
25+
#[tokio::test]
26+
async fn multi_count_distinct_matches_expected_with_nulls() -> Result<()> {
27+
let ctx = SessionContext::new();
28+
let schema = Arc::new(Schema::new(vec![
29+
Field::new("g", DataType::Int32, false),
30+
Field::new("b", DataType::Utf8, true),
31+
Field::new("c", DataType::Utf8, true),
32+
]));
33+
let batch = RecordBatch::try_new(
34+
schema.clone(),
35+
vec![
36+
Arc::new(Int32Array::from(vec![1, 1, 1])),
37+
Arc::new(StringArray::from(vec![Some("x"), None, Some("x")])),
38+
Arc::new(StringArray::from(vec![None, Some("y"), Some("y")])),
39+
],
40+
)?;
41+
let provider = MemTable::try_new(schema, vec![vec![batch]])?;
42+
ctx.register_table("t", Arc::new(provider))?;
43+
44+
let sql =
45+
"SELECT g, COUNT(DISTINCT b) AS cb, COUNT(DISTINCT c) AS cc FROM t GROUP BY g";
46+
let batches = ctx.sql(sql).await?.collect().await?;
47+
let out = batches_to_sort_string(&batches);
48+
49+
assert_eq!(
50+
out,
51+
"+---+----+----+\n\
52+
| g | cb | cc |\n\
53+
+---+----+----+\n\
54+
| 1 | 1 | 1 |\n\
55+
+---+----+----+"
56+
);
57+
Ok(())
58+
}

datafusion/optimizer/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ pub mod eliminate_outer_join;
5959
pub mod extract_equijoin_predicate;
6060
pub mod extract_leaf_expressions;
6161
pub mod filter_null_join_keys;
62+
pub mod multi_distinct_count_rewrite;
6263
pub mod optimize_projections;
6364
pub mod optimize_unions;
6465
pub mod optimizer;

0 commit comments

Comments
 (0)