Skip to content

Commit fbd720e

Browse files
committed
Actually preserve predicate execution order in PushDownFilter
1 parent d59bc72 commit fbd720e

4 files changed

Lines changed: 58 additions & 18 deletions

File tree

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -813,25 +813,22 @@ impl OptimizerRule for PushDownFilter {
813813

814814
match Arc::unwrap_or_clone(filter.input) {
815815
LogicalPlan::Filter(child_filter) => {
816-
let parents_predicates = split_conjunction_owned(filter.predicate);
817-
818-
// remove duplicated filters
819-
let child_predicates = split_conjunction_owned(child_filter.predicate);
820-
let new_predicates = parents_predicates
821-
.into_iter()
822-
.chain(child_predicates)
823-
// use IndexSet to remove dupes while preserving predicate order
824-
.collect::<IndexSet<_>>()
816+
// child filters first to preserve execution order
817+
let new_predicates = split_conjunction_owned(child_filter.predicate)
825818
.into_iter()
826-
.collect::<Vec<_>>();
819+
.chain(split_conjunction_owned(filter.predicate))
820+
// use IndexSet to remove duplicates while preserving predicate order
821+
.collect::<IndexSet<_>>();
827822

828823
let Some(new_predicate) = conjunction(new_predicates) else {
829824
return plan_err!("at least one expression exists");
830825
};
826+
831827
let new_filter = LogicalPlan::Filter(Filter::try_new(
832828
new_predicate,
833829
child_filter.input,
834830
)?);
831+
835832
self.rewrite(new_filter, config)
836833
}
837834
LogicalPlan::Repartition(repartition) => {
@@ -2486,7 +2483,7 @@ mod tests {
24862483
plan,
24872484
@r"
24882485
Projection: test.a
2489-
Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2486+
Filter: test.a <= Int64(1) AND test.a >= Int64(1)
24902487
Limit: skip=0, fetch=1
24912488
TableScan: test
24922489
"
@@ -3261,6 +3258,28 @@ mod tests {
32613258
)
32623259
}
32633260

3261+
#[test]
3262+
fn multi_combined_two_filters() -> Result<()> {
3263+
let plan = table_scan_with_pushdown_provider_builder(
3264+
TableProviderFilterPushDown::Inexact,
3265+
vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3266+
Some(vec![0]),
3267+
)?
3268+
.filter(col("a").eq(lit(10i64)))?
3269+
.filter(col("b").gt(lit(11i64)))?
3270+
.project(vec![col("a"), col("b")])?
3271+
.build()?;
3272+
3273+
assert_optimized_plan_equal!(
3274+
plan,
3275+
@r"
3276+
Projection: a, b
3277+
Filter: a = Int64(10) AND b > Int64(11)
3278+
TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3279+
"
3280+
)
3281+
}
3282+
32643283
#[test]
32653284
fn multi_combined_filter_exact() -> Result<()> {
32663285
let plan = table_scan_with_pushdown_provider_builder(
@@ -3281,6 +3300,27 @@ mod tests {
32813300
)
32823301
}
32833302

3303+
#[test]
3304+
fn multi_combined_two_filters_exact() -> Result<()> {
3305+
let plan = table_scan_with_pushdown_provider_builder(
3306+
TableProviderFilterPushDown::Exact,
3307+
vec![],
3308+
Some(vec![0]),
3309+
)?
3310+
.filter(col("a").eq(lit(10i64)))?
3311+
.filter(col("b").gt(lit(11i64)))?
3312+
.project(vec![col("a"), col("b")])?
3313+
.build()?;
3314+
3315+
assert_optimized_plan_equal!(
3316+
plan,
3317+
@r"
3318+
Projection: a, b
3319+
TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3320+
"
3321+
)
3322+
}
3323+
32843324
#[test]
32853325
fn test_filter_with_alias() -> Result<()> {
32863326
// in table scan the true col name is 'test.a',

datafusion/sqllogictest/test_files/predicates.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ logical_plan
665665
02)--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
666666
03)----Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)
667667
04)------TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)]
668-
05)----Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
668+
05)----Filter: part.p_size >= Int32(1) AND (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15))
669669
06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)]
670670
physical_plan
671671
01)HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_partkey@0]
@@ -674,7 +674,7 @@ physical_plan
674674
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
675675
05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_partkey, l_quantity], file_type=csv, has_header=true
676676
06)--RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4
677-
07)----FilterExec: (p_brand@1 = Brand#12 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_size@2 <= 15) AND p_size@2 >= 1
677+
07)----FilterExec: p_size@2 >= 1 AND (p_brand@1 = Brand#12 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_size@2 <= 15)
678678
08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
679679
09)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/part.csv]]}, projection=[p_partkey, p_brand, p_size], file_type=csv, has_header=true
680680

datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ logical_plan
5959
03)----Projection: lineitem.l_extendedprice, lineitem.l_discount
6060
04)------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
6161
05)--------Projection: lineitem.l_partkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount
62-
06)----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG")) AND lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON")
62+
06)----------Filter: (lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG")) AND lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON") AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2))
6363
07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG"), lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)]
64-
08)--------Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
64+
08)--------Filter: part.p_size >= Int32(1) AND (part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15))
6565
09)----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)]
6666
physical_plan
6767
01)ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@0 as revenue]
@@ -70,9 +70,9 @@ physical_plan
7070
04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
7171
05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3]
7272
06)----------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4
73-
07)------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON, projection=[l_partkey@0, l_quantity@1, l_extendedprice@2, l_discount@3]
73+
07)------------FilterExec: (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON AND (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2), projection=[l_partkey@0, l_quantity@1, l_extendedprice@2, l_discount@3]
7474
08)--------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], file_type=csv, has_header=false
7575
09)----------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4
76-
10)------------FilterExec: (p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND p_size@2 <= 15) AND p_size@2 >= 1
76+
10)------------FilterExec: p_size@2 >= 1 AND (p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND p_size@2 <= 15)
7777
11)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
7878
12)----------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], file_type=csv, has_header=false

0 commit comments

Comments
 (0)