Commit 274eeb6

Author: yash
feat: gate multi-distinct COUNT rewrite behind session config (default off)
✅ Add datafusion.optimizer.enable_multi_distinct_count_rewrite (default false).
✅ MultiDistinctCountRewrite no-ops when disabled; OptimizerContext::with_enable_multi_distinct_count_rewrite for tests.
✅ SQL integration tests enable the flag via session helper; unit test skips_rewrite_when_config_disabled.
✅ Document option in user-guide configs.md.
❌ Does not change rewrite semantics when enabled.

Made-with: Cursor
1 parent a7348f6 commit 274eeb6

5 files changed: +67 -7 lines changed

datafusion/common/src/config.rs

Lines changed: 5 additions & 0 deletions
@@ -969,6 +969,11 @@ config_namespace! {
         /// predicate push down.
         pub filter_null_join_keys: bool, default = false

+        /// When `true`, rewrite one grouped aggregate that has multiple `COUNT(DISTINCT …)` into
+        /// joins of per-distinct sub-aggregates (can lower peak memory; adds join work). Default
+        /// `false` until workload benchmarks justify enabling broadly.
+        pub enable_multi_distinct_count_rewrite: bool, default = false
+
         /// Should DataFusion repartition data using the aggregate keys to execute aggregates
         /// in parallel using the provided `target_partitions` level
         pub repartition_aggregations: bool, default = true
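
The doc comment above describes the rewrite only in words. As a rough illustration of why joining per-distinct sub-aggregates on the group key should give the same counts as a single aggregate, here is a minimal standard-library sketch. It does not use DataFusion; the `Row` type and the functions `direct`, `count_distinct_by_group`, and `rewritten` are hypothetical names invented for this example, and it says nothing about the memory or join cost of the real rule.

```rust
use std::collections::{BTreeMap, BTreeSet};

/// Hypothetical input row: group key `g` plus two nullable columns `b` and `c`.
type Row = (i32, Option<&'static str>, Option<i32>);

/// Single-aggregate form: each group holds both distinct sets at once.
fn direct(rows: &[Row]) -> BTreeMap<i32, (usize, usize)> {
    let mut state: BTreeMap<i32, (BTreeSet<&str>, BTreeSet<i32>)> = BTreeMap::new();
    for &(g, b, c) in rows {
        let entry = state.entry(g).or_default();
        // COUNT(DISTINCT x) ignores NULLs, modelled here as `None`.
        if let Some(b) = b {
            entry.0.insert(b);
        }
        if let Some(c) = c {
            entry.1.insert(c);
        }
    }
    state
        .into_iter()
        .map(|(g, (b, c))| (g, (b.len(), c.len())))
        .collect()
}

/// One sub-aggregate: COUNT(DISTINCT x) per group for a single column.
fn count_distinct_by_group<T: Ord>(
    pairs: impl Iterator<Item = (i32, Option<T>)>,
) -> BTreeMap<i32, usize> {
    let mut sets: BTreeMap<i32, BTreeSet<T>> = BTreeMap::new();
    for (g, x) in pairs {
        // The group entry is created even when `x` is NULL, so every group survives.
        let set = sets.entry(g).or_default();
        if let Some(x) = x {
            set.insert(x);
        }
    }
    sets.into_iter().map(|(g, s)| (g, s.len())).collect()
}

/// Rewritten form: one sub-aggregate per distinct column, joined on the group key.
fn rewritten(rows: &[Row]) -> BTreeMap<i32, (usize, usize)> {
    let by_b = count_distinct_by_group(rows.iter().map(|&(g, b, _)| (g, b)));
    let by_c = count_distinct_by_group(rows.iter().map(|&(g, _, c)| (g, c)));
    // Both maps contain exactly the same group keys, so indexing cannot panic.
    by_b.into_iter().map(|(g, nb)| (g, (nb, by_c[&g]))).collect()
}

fn main() {
    let rows: Vec<Row> = vec![
        (1, Some("x"), Some(10)),
        (1, Some("x"), Some(20)),
        (1, None, Some(20)),
        (2, None, None),
    ];
    assert_eq!(direct(&rows), rewritten(&rows));
    println!("{:?}", rewritten(&rows));
}
```

In this toy model a group whose values are all NULL still appears with a count of 0 on both sides of the join, which is the property the NULL-handling integration test in this commit is named after.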

datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs

Lines changed: 14 additions & 5 deletions
@@ -20,11 +20,20 @@
 use super::*;
 use arrow::array::{Float64Array, Int32Array, StringArray};
 use datafusion::common::test_util::batches_to_sort_string;
+use datafusion::execution::config::SessionConfig;
+use datafusion::execution::context::SessionContext;
 use datafusion_catalog::MemTable;

+fn session_with_multi_distinct_count_rewrite() -> SessionContext {
+    SessionContext::new_with_config(SessionConfig::new().set_bool(
+        "datafusion.optimizer.enable_multi_distinct_count_rewrite",
+        true,
+    ))
+}
+
 #[tokio::test]
 async fn multi_count_distinct_matches_expected_with_nulls() -> Result<()> {
-    let ctx = SessionContext::new();
+    let ctx = session_with_multi_distinct_count_rewrite();
     let schema = Arc::new(Schema::new(vec![
         Field::new("g", DataType::Int32, false),
         Field::new("b", DataType::Utf8, true),
@@ -60,7 +69,7 @@ async fn multi_count_distinct_matches_expected_with_nulls() -> Result<()> {
 /// `COUNT(*)` + two `COUNT(DISTINCT …)` per group (BI-style); must match non-rewritten semantics.
 #[tokio::test]
 async fn multi_count_distinct_with_count_star_matches_expected() -> Result<()> {
-    let ctx = SessionContext::new();
+    let ctx = session_with_multi_distinct_count_rewrite();
     let schema = Arc::new(Schema::new(vec![
         Field::new("g", DataType::Int32, false),
         Field::new("b", DataType::Int32, false),
@@ -96,7 +105,7 @@ async fn multi_count_distinct_with_count_star_matches_expected() -> Result<()> {
 /// Multiple `GROUP BY` keys: join must align on all keys.
 #[tokio::test]
 async fn multi_count_distinct_two_group_keys_matches_expected() -> Result<()> {
-    let ctx = SessionContext::new();
+    let ctx = session_with_multi_distinct_count_rewrite();
     let schema = Arc::new(Schema::new(vec![
         Field::new("g1", DataType::Int32, false),
         Field::new("g2", DataType::Int32, false),
@@ -136,7 +145,7 @@ async fn multi_count_distinct_two_group_keys_matches_expected() -> Result<()> {
 /// Two `COUNT(DISTINCT …)` so the rewrite applies; semantics match plain aggregation.
 #[tokio::test]
 async fn multi_count_distinct_lower_matches_expected_case_collapsing() -> Result<()> {
-    let ctx = SessionContext::new();
+    let ctx = session_with_multi_distinct_count_rewrite();
     let schema = Arc::new(Schema::new(vec![
         Field::new("g", DataType::Int32, false),
         Field::new("b", DataType::Utf8, false),
@@ -173,7 +182,7 @@ async fn multi_count_distinct_lower_matches_expected_case_collapsing() -> Result
 /// Exercises the same “expression in distinct, not raw column” path as `CAST` in the rule.
 #[tokio::test]
 async fn multi_count_distinct_cast_float_to_int_collapses_nearby_values() -> Result<()> {
-    let ctx = SessionContext::new();
+    let ctx = session_with_multi_distinct_count_rewrite();
     let schema = Arc::new(Schema::new(vec![
         Field::new("g", DataType::Int32, false),
         Field::new("x", DataType::Float64, false),

datafusion/optimizer/src/multi_distinct_count_rewrite.rs

Lines changed: 39 additions & 2 deletions
@@ -113,6 +113,14 @@ impl OptimizerRule for MultiDistinctCountRewrite {
         plan: LogicalPlan,
         config: &dyn OptimizerConfig,
     ) -> Result<Transformed<LogicalPlan>> {
+        if !config
+            .options()
+            .optimizer
+            .enable_multi_distinct_count_rewrite
+        {
+            return Ok(Transformed::no(plan));
+        }
+
         let LogicalPlan::Aggregate(Aggregate {
             input,
             aggr_expr,
@@ -351,17 +359,27 @@ mod tests {
     use datafusion_expr::{Expr, col};
     use datafusion_functions_aggregate::expr_fn::{count, count_distinct};

-    fn optimize_with_rule(
+    fn optimize_with_rule_config(
         plan: LogicalPlan,
         rule: Arc<dyn OptimizerRule + Send + Sync>,
+        enable_multi_distinct_count_rewrite: bool,
     ) -> Result<LogicalPlan> {
         Optimizer::with_rules(vec![rule]).optimize(
             plan,
-            &OptimizerContext::new(),
+            &OptimizerContext::new().with_enable_multi_distinct_count_rewrite(
+                enable_multi_distinct_count_rewrite,
+            ),
             |_, _| {},
         )
     }

+    fn optimize_with_rule(
+        plan: LogicalPlan,
+        rule: Arc<dyn OptimizerRule + Send + Sync>,
+    ) -> Result<LogicalPlan> {
+        optimize_with_rule_config(plan, rule, true)
+    }
+
     #[test]
     fn rewrites_two_count_distinct() -> Result<()> {
         let table_scan = test_table_scan()?;
@@ -585,6 +603,25 @@ mod tests {
         Ok(())
     }

+    #[test]
+    fn skips_rewrite_when_config_disabled() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(
+                vec![col("a")],
+                vec![count_distinct(col("b")), count_distinct(col("c"))],
+            )?
+            .build()?;
+        let before = plan.display_indent_schema().to_string();
+        let optimized = optimize_with_rule_config(
+            plan,
+            Arc::new(MultiDistinctCountRewrite::new()),
+            false,
+        )?;
+        assert_eq!(before, optimized.display_indent_schema().to_string());
+        Ok(())
+    }
+
     #[test]
     fn does_not_rewrite_mixed_agg() -> Result<()> {
         let table_scan = test_table_scan()?;
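
The early return added to the rule follows a common gating pattern: check the flag first and hand the plan back untouched when it is off. The sketch below models that pattern with the standard library only. `OptimizerOptions`, `Transformed`, and `rewrite` here are simplified, hypothetical stand-ins written for this example (the plan is just a string), not the DataFusion types of the same names.

```rust
/// Hypothetical stand-in for the optimizer options; disabled by default.
#[derive(Default)]
struct OptimizerOptions {
    enable_rewrite: bool,
}

/// Hypothetical stand-in for a wrapper that records whether a plan changed.
#[derive(Debug, PartialEq)]
struct Transformed<T> {
    data: T,
    transformed: bool,
}

impl<T> Transformed<T> {
    fn yes(data: T) -> Self {
        Self { data, transformed: true }
    }
    fn no(data: T) -> Self {
        Self { data, transformed: false }
    }
}

/// Toy rule over a string "plan": rewrites only when enabled and applicable.
fn rewrite(plan: String, options: &OptimizerOptions) -> Transformed<String> {
    // Gate: with the flag off, return the input unchanged and unmarked.
    if !options.enable_rewrite {
        return Transformed::no(plan);
    }
    // Toy applicability check: at least two COUNT(DISTINCT ...) expressions.
    if plan.matches("COUNT(DISTINCT").count() >= 2 {
        Transformed::yes(format!("JoinOfSubAggregates[{plan}]"))
    } else {
        Transformed::no(plan)
    }
}

fn main() {
    let plan = String::from("Aggregate: aggr=[COUNT(DISTINCT b), COUNT(DISTINCT c)]");

    // Default options leave the plan untouched, as in skips_rewrite_when_config_disabled.
    let off = rewrite(plan.clone(), &OptimizerOptions::default());
    assert_eq!(off, Transformed::no(plan.clone()));

    // With the flag on, the toy rule fires.
    let on = rewrite(plan, &OptimizerOptions { enable_rewrite: true });
    assert!(on.transformed);
}
```

Putting the check before any pattern matching on the plan keeps the disabled path free of side effects, which is what lets the unit test compare the plan's display string before and after optimization.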

datafusion/optimizer/src/optimizer.rs

Lines changed: 8 additions & 0 deletions
@@ -218,6 +218,14 @@ impl OptimizerContext {
         Arc::make_mut(&mut self.options).optimizer.max_passes = v as usize;
         self
     }
+
+    /// Enable [`crate::multi_distinct_count_rewrite::MultiDistinctCountRewrite`] (default off).
+    pub fn with_enable_multi_distinct_count_rewrite(mut self, enable: bool) -> Self {
+        Arc::make_mut(&mut self.options)
+            .optimizer
+            .enable_multi_distinct_count_rewrite = enable;
+        self
+    }
 }

 impl Default for OptimizerContext {
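
The new builder method relies on `Arc::make_mut`, which clones the shared options only if another `Arc` still points at them. A minimal standard-library sketch of that copy-on-write builder follows. `Options`, `Context`, and `with_enable_rewrite` are hypothetical simplified names for this example, not the DataFusion definitions.

```rust
use std::sync::Arc;

/// Hypothetical options struct; the flag defaults to `false`.
#[derive(Clone, Default, Debug)]
struct Options {
    enable_rewrite: bool,
}

/// Hypothetical context that shares its options behind an `Arc`.
#[derive(Clone, Default)]
struct Context {
    options: Arc<Options>,
}

impl Context {
    /// Builder-style setter: takes `self` by value and returns it for chaining.
    fn with_enable_rewrite(mut self, enable: bool) -> Self {
        // `make_mut` clones the inner value only when the Arc is shared.
        Arc::make_mut(&mut self.options).enable_rewrite = enable;
        self
    }
}

fn main() {
    let base = Context::default();
    let enabled = base.clone().with_enable_rewrite(true);

    // The new context sees the flag; the original is unaffected.
    assert!(enabled.options.enable_rewrite);
    assert!(!base.options.enable_rewrite);
    // Because `base` still held a reference, the options were copied.
    assert!(!Arc::ptr_eq(&base.options, &enabled.options));
}
```

Taking `self` by value lets test code chain the setter directly onto a constructor call, as the updated `optimize_with_rule_config` helper in this commit does.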

docs/source/user-guide/configs.md

Lines changed: 1 addition & 0 deletions
@@ -144,6 +144,7 @@ The following configuration settings are available:
 | datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown | true | When set to true, the optimizer will attempt to push down Aggregate dynamic filters into the file scan phase. |
 | datafusion.optimizer.enable_dynamic_filter_pushdown | true | When set to true attempts to push down dynamic filters generated by operators (TopK, Join & Aggregate) into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. The config will suppress `enable_join_dynamic_filter_pushdown`, `enable_topk_dynamic_filter_pushdown` & `enable_aggregate_dynamic_filter_pushdown` So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden. |
 | datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. |
+| datafusion.optimizer.enable_multi_distinct_count_rewrite | false | When set to true, the optimizer may rewrite a single aggregate with multiple `COUNT(DISTINCT …)` (with `GROUP BY`) into joins of per-distinct sub-aggregates. This can reduce peak memory but adds join work; default off until benchmarks support enabling broadly. |
 | datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level |
 | datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. |
 | datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level |
