Skip to content

Commit df32859

Browse files
authored
Enable inlist support for preimage (#20051)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #20050 ## Rationale for this change Check issue <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? Match arm to support preimage for InList expressions in expr_simplifier.rs <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? Yes, added two tests for `IN LIST` and `NOT IN LIST` support. <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> No <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 2c54cde commit df32859

4 files changed

Lines changed: 136 additions & 2 deletions

File tree

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2044,6 +2044,53 @@ impl TreeNodeRewriter for Simplifier<'_> {
20442044
Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right }))
20452045
}
20462046
}
2047+
// For case:
2048+
// date_part('YEAR', expr) IN (literal1, literal2, ...)
2049+
Expr::InList(InList {
2050+
expr,
2051+
list,
2052+
negated,
2053+
}) => {
2054+
if list.len() > THRESHOLD_INLINE_INLIST || list.iter().any(is_null) {
2055+
return Ok(Transformed::no(Expr::InList(InList {
2056+
expr,
2057+
list,
2058+
negated,
2059+
})));
2060+
}
2061+
2062+
let (op, combiner): (Operator, fn(Expr, Expr) -> Expr) =
2063+
if negated { (NotEq, and) } else { (Eq, or) };
2064+
2065+
let mut rewritten: Option<Expr> = None;
2066+
for item in &list {
2067+
let PreimageResult::Range { interval, expr } =
2068+
get_preimage(expr.as_ref(), item, info)?
2069+
else {
2070+
return Ok(Transformed::no(Expr::InList(InList {
2071+
expr,
2072+
list,
2073+
negated,
2074+
})));
2075+
};
2076+
2077+
let range_expr = rewrite_with_preimage(*interval, op, expr)?.data;
2078+
rewritten = Some(match rewritten {
2079+
None => range_expr,
2080+
Some(acc) => combiner(acc, range_expr),
2081+
});
2082+
}
2083+
2084+
if let Some(rewritten) = rewritten {
2085+
Transformed::yes(rewritten)
2086+
} else {
2087+
Transformed::no(Expr::InList(InList {
2088+
expr,
2089+
list,
2090+
negated,
2091+
}))
2092+
}
2093+
}
20472094

20482095
// no additional rewrites possible
20492096
expr => Transformed::no(expr),

datafusion/optimizer/src/simplify_expressions/udf_preimage.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ mod test {
7575
use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue};
7676
use datafusion_expr::{
7777
ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
78-
Signature, Volatility, and, binary_expr, col, lit, preimage::PreimageResult,
78+
Signature, Volatility, and, binary_expr, col, lit, or, preimage::PreimageResult,
7979
simplify::SimplifyContext,
8080
};
8181

@@ -164,6 +164,15 @@ mod test {
164164
)?),
165165
})
166166
}
167+
Expr::Literal(ScalarValue::Int32(Some(600)), _) => {
168+
Ok(PreimageResult::Range {
169+
expr,
170+
interval: Box::new(Interval::try_new(
171+
ScalarValue::Int32(Some(300)),
172+
ScalarValue::Int32(Some(400)),
173+
)?),
174+
})
175+
}
167176
_ => Ok(PreimageResult::None),
168177
}
169178
}
@@ -311,6 +320,38 @@ mod test {
311320
assert_eq!(optimize_test(expr, &schema), expected);
312321
}
313322

323+
#[test]
324+
fn test_preimage_in_list_rewrite() {
325+
let schema = test_schema();
326+
let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], false);
327+
let expected = or(
328+
and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))),
329+
and(col("x").gt_eq(lit(300)), col("x").lt(lit(400))),
330+
);
331+
332+
assert_eq!(optimize_test(expr, &schema), expected);
333+
}
334+
335+
#[test]
336+
fn test_preimage_not_in_list_rewrite() {
337+
let schema = test_schema();
338+
let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], true);
339+
let expected = and(
340+
or(col("x").lt(lit(100)), col("x").gt_eq(lit(200))),
341+
or(col("x").lt(lit(300)), col("x").gt_eq(lit(400))),
342+
);
343+
344+
assert_eq!(optimize_test(expr, &schema), expected);
345+
}
346+
347+
#[test]
348+
fn test_preimage_in_list_long_list_no_rewrite() {
349+
let schema = test_schema();
350+
let expr = preimage_udf_expr().in_list((1..100).map(lit).collect(), false);
351+
352+
assert_eq!(optimize_test(expr.clone(), &schema), expr);
353+
}
354+
314355
#[test]
315356
fn test_preimage_non_literal_rhs_no_rewrite() {
316357
// Non-literal RHS should not be rewritten.

datafusion/sqllogictest/test_files/datetime/date_part.slt

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,19 @@ NULL
12471247
1990-01-01
12481248
2030-01-01
12491249

1250+
# IN list optimization
1251+
query D
1252+
select c from t1 where extract(year from c) in (1990, 2024);
1253+
----
1254+
1990-01-01
1255+
2024-01-01
1256+
1257+
# NOT IN list optimization (NULL does not satisfy NOT IN)
1258+
query D
1259+
select c from t1 where extract(year from c) not in (1990, 2024);
1260+
----
1261+
2030-01-01
1262+
12501263
# Check that date_part is not in the explain statements
12511264

12521265
query TT
@@ -1329,6 +1342,16 @@ physical_plan
13291342
01)FilterExec: c@0 < 2024-01-01 OR c@0 >= 2025-01-01 OR c@0 IS NULL
13301343
02)--DataSourceExec: partitions=1, partition_sizes=[1]
13311344

1345+
query TT
1346+
explain select c from t1 where extract (year from c) in (1990, 2024)
1347+
----
1348+
logical_plan
1349+
01)Filter: t1.c >= Date32("1990-01-01") AND t1.c < Date32("1991-01-01") OR t1.c >= Date32("2024-01-01") AND t1.c < Date32("2025-01-01")
1350+
02)--TableScan: t1 projection=[c]
1351+
physical_plan
1352+
01)FilterExec: c@0 >= 1990-01-01 AND c@0 < 1991-01-01 OR c@0 >= 2024-01-01 AND c@0 < 2025-01-01
1353+
02)--DataSourceExec: partitions=1, partition_sizes=[1]
1354+
13321355
# Simple optimizations, column on RHS
13331356

13341357
query D
@@ -1730,4 +1753,4 @@ logical_plan
17301753
02)--TableScan: t1 projection=[c]
17311754
physical_plan
17321755
01)FilterExec: c@0 >= 2024-01-01 AND c@0 < 2025-01-01
1733-
02)--DataSourceExec: partitions=1, partition_sizes=[1]
1756+
02)--DataSourceExec: partitions=1, partition_sizes=[1]

datafusion/sqllogictest/test_files/floor_preimage.slt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,21 @@ query I rowsort
104104
SELECT id FROM test_data WHERE floor(float_val) = arrow_cast(5.5, 'Float64');
105105
----
106106

107+
# IN list: floor(x) IN (5, 7) matches [5.0, 6.0) and [7.0, 8.0)
108+
query I rowsort
109+
SELECT id FROM test_data WHERE floor(float_val) IN (arrow_cast(5, 'Float64'), arrow_cast(7, 'Float64'));
110+
----
111+
1
112+
2
113+
5
114+
115+
# NOT IN list: floor(x) NOT IN (5, 7) excludes matching ranges and NULLs
116+
query I rowsort
117+
SELECT id FROM test_data WHERE floor(float_val) NOT IN (arrow_cast(5, 'Float64'), arrow_cast(7, 'Float64'));
118+
----
119+
3
120+
4
121+
107122
##########
108123
## EXPLAIN Tests - Plan Optimization
109124
##########
@@ -177,6 +192,14 @@ logical_plan
177192
01)Filter: floor(test_data.float_val) = Float64(9007199254740992)
178193
02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val]
179194

195+
# 9. IN list: each list item is rewritten with preimage and OR-ed together
196+
query TT
197+
EXPLAIN SELECT * FROM test_data WHERE floor(float_val) IN (arrow_cast(5, 'Float64'), arrow_cast(7, 'Float64'));
198+
----
199+
logical_plan
200+
01)Filter: test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) OR test_data.float_val >= Float64(7) AND test_data.float_val < Float64(8)
201+
02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val]
202+
180203
# Data correctness: floor(col) = 2^53 returns no rows (no value in test_data has floor exactly 2^53)
181204
query I rowsort
182205
SELECT id FROM test_data WHERE floor(float_val) = 9007199254740992;

0 commit comments

Comments
 (0)