Skip to content

Commit 336af3d

Browse files
committed
Actually preserve predicate execution order in PushDownFilter
1 parent f99ba69 commit 336af3d

3 files changed

Lines changed: 57 additions & 17 deletions

File tree

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -810,25 +810,22 @@ impl OptimizerRule for PushDownFilter {
810810

811811
match Arc::unwrap_or_clone(filter.input) {
812812
LogicalPlan::Filter(child_filter) => {
813-
let parents_predicates = split_conjunction_owned(filter.predicate);
814-
815-
// remove duplicated filters
816-
let child_predicates = split_conjunction_owned(child_filter.predicate);
817-
let new_predicates = parents_predicates
818-
.into_iter()
819-
.chain(child_predicates)
820-
// use IndexSet to remove dupes while preserving predicate order
821-
.collect::<IndexSet<_>>()
813+
// child filters first to preserve execution order
814+
let new_predicates = split_conjunction_owned(child_filter.predicate)
822815
.into_iter()
823-
.collect::<Vec<_>>();
816+
.chain(split_conjunction_owned(filter.predicate))
817+
// use IndexSet to remove duplicates while preserving predicate order
818+
.collect::<IndexSet<_>>();
824819

825820
let Some(new_predicate) = conjunction(new_predicates) else {
826821
return plan_err!("at least one expression exists");
827822
};
823+
828824
let new_filter = LogicalPlan::Filter(Filter::try_new(
829825
new_predicate,
830826
child_filter.input,
831827
)?);
828+
832829
self.rewrite(new_filter, config)
833830
}
834831
LogicalPlan::Repartition(repartition) => {
@@ -2474,7 +2471,7 @@ mod tests {
24742471
plan,
24752472
@r"
24762473
Projection: test.a
2477-
Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2474+
Filter: test.a <= Int64(1) AND test.a >= Int64(1)
24782475
Limit: skip=0, fetch=1
24792476
TableScan: test
24802477
"
@@ -3253,6 +3250,28 @@ mod tests {
32533250
)
32543251
}
32553252

3253+
#[test]
3254+
fn multi_combined_two_filters() -> Result<()> {
3255+
let plan = table_scan_with_pushdown_provider_builder(
3256+
TableProviderFilterPushDown::Inexact,
3257+
vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3258+
Some(vec![0]),
3259+
)?
3260+
.filter(col("a").eq(lit(10i64)))?
3261+
.filter(col("b").gt(lit(11i64)))?
3262+
.project(vec![col("a"), col("b")])?
3263+
.build()?;
3264+
3265+
assert_optimized_plan_equal!(
3266+
plan,
3267+
@r"
3268+
Projection: a, b
3269+
Filter: a = Int64(10) AND b > Int64(11)
3270+
TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3271+
"
3272+
)
3273+
}
3274+
32563275
#[test]
32573276
fn multi_combined_filter_exact() -> Result<()> {
32583277
let plan = table_scan_with_pushdown_provider_builder(
@@ -3273,6 +3292,27 @@ mod tests {
32733292
)
32743293
}
32753294

3295+
#[test]
3296+
fn multi_combined_two_filters_exact() -> Result<()> {
3297+
let plan = table_scan_with_pushdown_provider_builder(
3298+
TableProviderFilterPushDown::Exact,
3299+
vec![],
3300+
Some(vec![0]),
3301+
)?
3302+
.filter(col("a").eq(lit(10i64)))?
3303+
.filter(col("b").gt(lit(11i64)))?
3304+
.project(vec![col("a"), col("b")])?
3305+
.build()?;
3306+
3307+
assert_optimized_plan_equal!(
3308+
plan,
3309+
@r"
3310+
Projection: a, b
3311+
TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3312+
"
3313+
)
3314+
}
3315+
32763316
#[test]
32773317
fn test_filter_with_alias() -> Result<()> {
32783318
// in table scan the true col name is 'test.a',

datafusion/sqllogictest/test_files/predicates.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ logical_plan
665665
02)--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
666666
03)----Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)
667667
04)------TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)]
668-
05)----Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
668+
05)----Filter: part.p_size >= Int32(1) AND (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15))
669669
06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)]
670670
physical_plan
671671
01)HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_partkey@0]
@@ -674,7 +674,7 @@ physical_plan
674674
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
675675
05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_partkey, l_quantity], file_type=csv, has_header=true
676676
06)--RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4
677-
07)----FilterExec: (p_brand@1 = Brand#12 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_size@2 <= 15) AND p_size@2 >= 1
677+
07)----FilterExec: p_size@2 >= 1 AND (p_brand@1 = Brand#12 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_size@2 <= 15)
678678
08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
679679
09)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/part.csv]]}, projection=[p_partkey, p_brand, p_size], file_type=csv, has_header=true
680680

datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ logical_plan
5959
03)----Projection: lineitem.l_extendedprice, lineitem.l_discount
6060
04)------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
6161
05)--------Projection: lineitem.l_partkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount
62-
06)----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG")) AND lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON")
62+
06)----------Filter: (lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG")) AND lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON") AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2))
6363
07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG"), lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)]
64-
08)--------Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
64+
08)--------Filter: part.p_size >= Int32(1) AND (part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15))
6565
09)----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)]
6666
physical_plan
6767
01)ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@0 as revenue]
@@ -70,9 +70,9 @@ physical_plan
7070
04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
7171
05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3]
7272
06)----------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4
73-
07)------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON, projection=[l_partkey@0, l_quantity@1, l_extendedprice@2, l_discount@3]
73+
07)------------FilterExec: (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON AND (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2), projection=[l_partkey@0, l_quantity@1, l_extendedprice@2, l_discount@3]
7474
08)--------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], file_type=csv, has_header=false
7575
09)----------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4
76-
10)------------FilterExec: (p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND p_size@2 <= 15) AND p_size@2 >= 1
76+
10)------------FilterExec: p_size@2 >= 1 AND (p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND p_size@2 <= 15)
7777
11)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
7878
12)----------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], file_type=csv, has_header=false

0 commit comments

Comments
 (0)