diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 63962998ad18b..24ca33c0c2c90 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -590,6 +590,53 @@ pub fn get_required_group_by_exprs_indices( .collect() } +/// Returns indices for the minimal subset of ORDER BY expressions that are +/// functionally equivalent to the original set of ORDER BY expressions. +pub fn get_required_sort_exprs_indices( + schema: &DFSchema, + sort_expr_names: &[String], +) -> Vec<usize> { + let dependencies = schema.functional_dependencies(); + let field_names = schema.field_names(); + + let mut known_field_indices = HashSet::new(); + let mut required_sort_expr_indices = Vec::new(); + + for (sort_expr_idx, sort_expr_name) in sort_expr_names.iter().enumerate() { + // If the sort expression doesn't correspond to a known schema field + // (e.g. a computed expression), we can't reason about it via functional + // dependencies, so conservatively keep it. + let Some(field_idx) = field_names + .iter() + .position(|field_name| field_name == sort_expr_name) + else { + required_sort_expr_indices.push(sort_expr_idx); + continue; + }; + + // A sort expression is removable if its value is functionally determined + // by fields that already appear earlier in the sort order: if the earlier + // fields are fixed, this one's value is fixed too, so it adds no ordering + // information. 
+ let removable = dependencies.deps.iter().any(|dependency| { + dependency.target_indices.contains(&field_idx) + && dependency + .source_indices + .iter() + .all(|source_idx| known_field_indices.contains(source_idx)) + }); + + if removable { + continue; + } + + known_field_indices.insert(field_idx); + required_sort_expr_indices.push(sort_expr_idx); + } + + required_sort_expr_indices +} + /// Updates entries inside the `entries` vector with their corresponding /// indices inside the `proj_indices` vector. fn update_elements_with_matching_indices( diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index fdd04f752455e..996c563f0d8a2 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -82,7 +82,7 @@ pub use file_options::file_type::{ pub use functional_dependencies::{ Constraint, Constraints, Dependency, FunctionalDependence, FunctionalDependencies, aggregate_functional_dependencies, get_required_group_by_exprs_indices, - get_target_functional_dependencies, + get_required_sort_exprs_indices, get_target_functional_dependencies, }; use hashbrown::DefaultHashBuilder; pub use join_type::{JoinConstraint, JoinSide, JoinType}; diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 113c92c2c8e99..97aa6e1d8480d 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -19,8 +19,8 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::Result; use datafusion_common::tree_node::Transformed; +use datafusion_common::{Result, get_required_sort_exprs_indices, internal_err}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{Aggregate, Expr, Sort, SortExpr}; use std::hash::{Hash, Hasher}; @@ -76,12 +76,36 @@ impl OptimizerRule for EliminateDuplicatedExpr { .map(|wrapper| wrapper.0) .collect(); + let sort_expr_names = 
unique_exprs + .iter() + .map(|sort_expr| sort_expr.expr.schema_name().to_string()) + .collect::<Vec<String>>(); + let required_indices = get_required_sort_exprs_indices( + sort.input.schema().as_ref(), + &sort_expr_names, + ); + + let unique_exprs = if required_indices.len() < unique_exprs.len() { + required_indices + .into_iter() + .map(|idx| unique_exprs[idx].clone()) + .collect() + } else { + unique_exprs + }; + let transformed = if len != unique_exprs.len() { Transformed::yes } else { Transformed::no }; + if unique_exprs.is_empty() { + return internal_err!( + "FD pruning unexpectedly removed all ORDER BY expressions" + ); + } + Ok(transformed(LogicalPlan::Sort(Sort { expr: unique_exprs, input: sort.input, @@ -130,7 +154,8 @@ mod tests { @ $expected:literal $(,)? ) => {{ let optimizer_ctx = OptimizerContext::new().with_max_passes(1); - let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = vec![Arc::new(EliminateDuplicatedExpr::new())]; + let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = + vec![Arc::new(EliminateDuplicatedExpr::new())]; assert_optimized_plan_eq_snapshot!( optimizer_ctx, rules, diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 488a32c7acde5..ffd48d5996576 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -260,6 +260,64 @@ physical_plan 02)--SortExec: expr=[c2@1 ASC NULLS LAST, c3@2 ASC NULLS LAST], preserve_partitioning=[false] 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], file_type=csv, has_header=true + +# eliminate redundant fd sort expr +query TT +explain SELECT c2, SUM(c3) AS total_sal FROM aggregate_test_100 GROUP BY c2 ORDER BY c2, total_sal +---- +logical_plan +01)Sort: aggregate_test_100.c2 ASC NULLS LAST +02)--Projection: aggregate_test_100.c2, sum(aggregate_test_100.c3) AS total_sal +03)----Aggregate: groupBy=[[aggregate_test_100.c2]], aggr=[[sum(CAST(aggregate_test_100.c3 AS Int64))]] +04)------TableScan: 
aggregate_test_100 projection=[c2, c3] +physical_plan +01)SortPreservingMergeExec: [c2@0 ASC NULLS LAST] +02)--SortExec: expr=[c2@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[c2@0 as c2, sum(aggregate_test_100.c3)@1 as total_sal] +04)------AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[sum(aggregate_test_100.c3)] +05)--------RepartitionExec: partitioning=Hash([c2@0], 4), input_partitions=4 +06)----------AggregateExec: mode=Partial, gby=[c2@0 as c2], aggr=[sum(aggregate_test_100.c3)] +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true + +# keep order by when dependency comes later +query TT +explain SELECT c2, SUM(c3) AS total_sal FROM aggregate_test_100 GROUP BY c2 ORDER BY total_sal, c2 +---- +logical_plan +01)Sort: total_sal ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST +02)--Projection: aggregate_test_100.c2, sum(aggregate_test_100.c3) AS total_sal +03)----Aggregate: groupBy=[[aggregate_test_100.c2]], aggr=[[sum(CAST(aggregate_test_100.c3 AS Int64))]] +04)------TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +01)SortPreservingMergeExec: [total_sal@1 ASC NULLS LAST, c2@0 ASC NULLS LAST] +02)--SortExec: expr=[total_sal@1 ASC NULLS LAST, c2@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[c2@0 as c2, sum(aggregate_test_100.c3)@1 as total_sal] +04)------AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[sum(aggregate_test_100.c3)] +05)--------RepartitionExec: partitioning=Hash([c2@0], 4), input_partitions=4 +06)----------AggregateExec: mode=Partial, gby=[c2@0 as c2], aggr=[sum(aggregate_test_100.c3)] +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true + +# eliminate redundant sort expr even with non schema expr +query TT +explain SELECT c2, SUM(c3) AS total_sal FROM aggregate_test_100 GROUP BY c2 ORDER BY c2, total_sal, abs(c2) +---- +logical_plan +01)Sort: aggregate_test_100.c2 ASC NULLS LAST, abs(aggregate_test_100.c2) ASC NULLS LAST +02)--Projection: aggregate_test_100.c2, sum(aggregate_test_100.c3) AS total_sal +03)----Aggregate: groupBy=[[aggregate_test_100.c2]], aggr=[[sum(CAST(aggregate_test_100.c3 AS Int64))]] +04)------TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +01)SortPreservingMergeExec: [c2@0 ASC NULLS LAST, abs(c2@0) ASC NULLS LAST] +02)--SortExec: expr=[c2@0 ASC NULLS LAST, abs(c2@0) ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[c2@0 as c2, sum(aggregate_test_100.c3)@1 as total_sal] +04)------AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[sum(aggregate_test_100.c3)] +05)--------RepartitionExec: partitioning=Hash([c2@0], 4), input_partitions=4 +06)----------AggregateExec: mode=Partial, gby=[c2@0 as c2], aggr=[sum(aggregate_test_100.c3)] +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true + query II SELECT c2, c3 FROM aggregate_test_100 ORDER BY c2, c3, c2 ---- @@ -1521,7 +1579,7 @@ query TT EXPLAIN SELECT c1, c2 FROM table_with_ordered_pk ORDER BY c1, c2; ---- logical_plan -01)Sort: table_with_ordered_pk.c1 ASC NULLS LAST, table_with_ordered_pk.c2 ASC NULLS LAST +01)Sort: table_with_ordered_pk.c1 ASC NULLS LAST 02)--TableScan: table_with_ordered_pk projection=[c1, c2] physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_agg_multi_order.csv]]}, projection=[c1, c2], 
output_ordering=[c1@0 ASC NULLS LAST], constraints=[PrimaryKey([0])], file_type=csv, has_header=true diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index ae83045961488..caaf22f0adbd8 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -2425,11 +2425,9 @@ SELECT c9, rn1 FROM (SELECT c9, 145294611 96 # test_c9_rn_ordering_alias_opposite_direction3 -# These test check for whether datafusion is aware of the ordering of the column generated by ROW_NUMBER() window function. -# Physical plan should have a SortExec after BoundedWindowAggExec. -# The reason is that ordering of the table after BoundedWindowAggExec can be described as rn1 ASC, and also c9 DESC. -# However, the requirement is rn1 ASC, c9 ASC (lexicographical order). Hence existing ordering cannot satisfy requirement -# (Requirement is finer than existing ordering) +# These tests check whether DataFusion tracks the ordering of the column generated by ROW_NUMBER() window function. +# The outer ORDER BY can be simplified by ordering equivalence, so the plan should not need an additional SortExec +# beyond the one required to satisfy the window input order. 
query TT EXPLAIN SELECT c9, rn1 FROM (SELECT c9, ROW_NUMBER() OVER(ORDER BY c9 DESC) as rn1 @@ -2439,13 +2437,13 @@ EXPLAIN SELECT c9, rn1 FROM (SELECT c9, LIMIT 5 ---- logical_plan -01)Sort: rn1 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST, fetch=5 +01)Sort: rn1 ASC NULLS LAST, fetch=5 02)--Projection: aggregate_test_100.c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 03)----WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 04)------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)SortExec: TopK(fetch=5), expr=[rn1@1 ASC NULLS LAST, c9@0 ASC NULLS LAST], preserve_partitioning=[false], sort_prefix=[rn1@1 ASC NULLS LAST] -02)--ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] +01)ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] +02)--GlobalLimitExec: skip=0, fetch=5 03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW": UInt64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true @@ -2514,7 +2512,7 @@ EXPLAIN SELECT c9, rn1 FROM (SELECT c9, LIMIT 5 ---- logical_plan -01)Sort: rn1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST, fetch=5 +01)Sort: rn1 ASC NULLS LAST, fetch=5 02)--Projection: 
aggregate_test_100.c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 03)----WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 04)------TableScan: aggregate_test_100 projection=[c9]