Skip to content

Commit b5d6c25

Browse files
committed
Actually preserve predicate execution order in PushDownFilter
1 parent edf8ad3 commit b5d6c25

2 files changed

Lines changed: 53 additions & 13 deletions

File tree

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 51 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -810,25 +810,22 @@ impl OptimizerRule for PushDownFilter {
810810

811811
match Arc::unwrap_or_clone(filter.input) {
812812
LogicalPlan::Filter(child_filter) => {
813-
let parents_predicates = split_conjunction_owned(filter.predicate);
814-
815-
// remove duplicated filters
816-
let child_predicates = split_conjunction_owned(child_filter.predicate);
817-
let new_predicates = parents_predicates
818-
.into_iter()
819-
.chain(child_predicates)
820-
// use IndexSet to remove dupes while preserving predicate order
821-
.collect::<IndexSet<_>>()
813+
// child filters first to preserve execution order
814+
let new_predicates = split_conjunction_owned(child_filter.predicate)
822815
.into_iter()
823-
.collect::<Vec<_>>();
816+
.chain(split_conjunction_owned(filter.predicate))
817+
// use IndexSet to remove duplicates while preserving predicate order
818+
.collect::<IndexSet<_>>();
824819

825820
let Some(new_predicate) = conjunction(new_predicates) else {
826821
return plan_err!("at least one expression exists");
827822
};
823+
828824
let new_filter = LogicalPlan::Filter(Filter::try_new(
829825
new_predicate,
830826
child_filter.input,
831827
)?);
828+
832829
self.rewrite(new_filter, config)
833830
}
834831
LogicalPlan::Repartition(repartition) => {
@@ -2474,7 +2471,7 @@ mod tests {
24742471
plan,
24752472
@r"
24762473
Projection: test.a
2477-
Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2474+
Filter: test.a <= Int64(1) AND test.a >= Int64(1)
24782475
Limit: skip=0, fetch=1
24792476
TableScan: test
24802477
"
@@ -3253,6 +3250,28 @@ mod tests {
32533250
)
32543251
}
32553252

3253+
#[test]
3254+
fn multi_combined_two_filters() -> Result<()> {
3255+
let plan = table_scan_with_pushdown_provider_builder(
3256+
TableProviderFilterPushDown::Inexact,
3257+
vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3258+
Some(vec![0]),
3259+
)?
3260+
.filter(col("a").eq(lit(10i64)))?
3261+
.filter(col("b").gt(lit(11i64)))?
3262+
.project(vec![col("a"), col("b")])?
3263+
.build()?;
3264+
3265+
assert_optimized_plan_equal!(
3266+
plan,
3267+
@r"
3268+
Projection: a, b
3269+
Filter: a = Int64(10) AND b > Int64(11)
3270+
TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3271+
"
3272+
)
3273+
}
3274+
32563275
#[test]
32573276
fn multi_combined_filter_exact() -> Result<()> {
32583277
let plan = table_scan_with_pushdown_provider_builder(
@@ -3273,6 +3292,27 @@ mod tests {
32733292
)
32743293
}
32753294

3295+
#[test]
3296+
fn multi_combined_two_filters_exact() -> Result<()> {
3297+
let plan = table_scan_with_pushdown_provider_builder(
3298+
TableProviderFilterPushDown::Exact,
3299+
vec![],
3300+
Some(vec![0]),
3301+
)?
3302+
.filter(col("a").eq(lit(10i64)))?
3303+
.filter(col("b").gt(lit(11i64)))?
3304+
.project(vec![col("a"), col("b")])?
3305+
.build()?;
3306+
3307+
assert_optimized_plan_equal!(
3308+
plan,
3309+
@r"
3310+
Projection: a, b
3311+
TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3312+
"
3313+
)
3314+
}
3315+
32763316
#[test]
32773317
fn test_filter_with_alias() -> Result<()> {
32783318
// in table scan the true col name is 'test.a',

datafusion/sqllogictest/test_files/predicates.slt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -665,7 +665,7 @@ logical_plan
665665
02)--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
666666
03)----Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)
667667
04)------TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)]
668-
05)----Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
668+
05)----Filter: part.p_size >= Int32(1) AND (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15))
669669
06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)]
670670
physical_plan
671671
01)HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_partkey@0]
@@ -674,7 +674,7 @@ physical_plan
674674
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
675675
05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_partkey, l_quantity], file_type=csv, has_header=true
676676
06)--RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4
677-
07)----FilterExec: (p_brand@1 = Brand#12 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_size@2 <= 15) AND p_size@2 >= 1
677+
07)----FilterExec: p_size@2 >= 1 AND (p_brand@1 = Brand#12 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_size@2 <= 15)
678678
08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
679679
09)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/part.csv]]}, projection=[p_partkey, p_brand, p_size], file_type=csv, has_header=true
680680

0 commit comments

Comments
 (0)