Skip to content

Commit 4bff17e

Browse files
authored
[Minor]: unify ANY/ALL planning and align ANY NULL semantics with PG (#21743)
## Which issue does this PR close? Related with #2547 and #2548 but does not close ## Rationale for this change In #21416 I've aligned ALL operator NULL semantics to Postgres while supporting additional operators. I've implemented ANY operator but missed that part. Initially that PR included these changes but with the suggestion from @Jefffrey I've separated those changes and opened this PR. ## What changes are included in this PR? - Refactor ANY operator to use same logic with ALL operator - align null semantics with postgres | Query | PostgreSQL | This PR | DuckDB | |---|---|---|---| | `5 = ANY(NULL::INT[])` | `NULL` | `NULL` | `false` | | `5 <> ANY(NULL::INT[])` | `NULL` | `NULL` | `false` | | `5 > ANY(NULL::INT[])` | `NULL` | `NULL` | `false` | | `5 < ANY(NULL::INT[])` | `NULL` | `NULL` | `false` | | `5 >= ANY(NULL::INT[])` | `NULL` | `NULL` | `false` | | `5 <= ANY(NULL::INT[])` | `NULL` | `NULL` | `false` | #### I'll explore a followup implementation on ANY and ALL as UDFs instead of this case approach to see if it will perform faster. I've wanted to open this PR to correct out the NULL behavior ## Are these changes tested? Yes existing and additional slt tests. ## Are there any user-facing changes? Yes users will be able to see null semantics are same as postgres
1 parent ff844be commit 4bff17e

3 files changed

Lines changed: 104 additions & 90 deletions

File tree

datafusion/sql/src/expr/mod.rs

Lines changed: 50 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,12 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
621621
_ => {
622622
let left_expr = self.sql_to_expr(*left, schema, planner_context)?;
623623
let right_expr = self.sql_to_expr(*right, schema, planner_context)?;
624-
plan_any_op(left_expr, right_expr, &compare_op)
624+
plan_quantified_op(
625+
&left_expr,
626+
&right_expr,
627+
&compare_op,
628+
SetQuantifier::Any,
629+
)
625630
}
626631
},
627632
SQLExpr::AllOp {
@@ -640,7 +645,12 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
640645
_ => {
641646
let left_expr = self.sql_to_expr(*left, schema, planner_context)?;
642647
let right_expr = self.sql_to_expr(*right, schema, planner_context)?;
643-
plan_all_op(&left_expr, &right_expr, &compare_op)
648+
plan_quantified_op(
649+
&left_expr,
650+
&right_expr,
651+
&compare_op,
652+
SetQuantifier::All,
653+
)
644654
}
645655
},
646656
#[expect(deprecated)]
@@ -1249,73 +1259,20 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
12491259
}
12501260
}
12511261

1252-
/// Builds a CASE expression that handles NULL semantics for `x <op> ANY(arr)`:
1253-
///
1254-
/// ```text
1255-
/// CASE
1256-
/// WHEN <min_or_max>(arr) IS NOT NULL THEN <comparison>
1257-
/// WHEN arr IS NOT NULL THEN FALSE -- empty or all-null array
1258-
/// ELSE NULL -- NULL array
1259-
/// END
1260-
/// ```
1261-
fn any_op_with_null_handling(bound: Expr, comparison: Expr, arr: Expr) -> Result<Expr> {
1262-
when(bound.is_not_null(), comparison)
1263-
.when(arr.is_not_null(), lit(false))
1264-
.otherwise(lit(ScalarValue::Boolean(None)))
1265-
}
1266-
1267-
/// Plans a `<left> <op> ANY(<right>)` expression for non-subquery operands.
1268-
fn plan_any_op(
1269-
left_expr: Expr,
1270-
right_expr: Expr,
1271-
compare_op: &BinaryOperator,
1272-
) -> Result<Expr> {
1273-
match compare_op {
1274-
BinaryOperator::Eq => Ok(array_has(right_expr, left_expr)),
1275-
BinaryOperator::NotEq => {
1276-
let min = array_min(right_expr.clone());
1277-
let max = array_max(right_expr.clone());
1278-
// NOT EQ is true when either bound differs from left
1279-
let comparison = min
1280-
.not_eq(left_expr.clone())
1281-
.or(max.clone().not_eq(left_expr));
1282-
any_op_with_null_handling(max, comparison, right_expr)
1283-
}
1284-
BinaryOperator::Gt => {
1285-
let min = array_min(right_expr.clone());
1286-
any_op_with_null_handling(min.clone(), min.lt(left_expr), right_expr)
1287-
}
1288-
BinaryOperator::Lt => {
1289-
let max = array_max(right_expr.clone());
1290-
any_op_with_null_handling(max.clone(), max.gt(left_expr), right_expr)
1291-
}
1292-
BinaryOperator::GtEq => {
1293-
let min = array_min(right_expr.clone());
1294-
any_op_with_null_handling(min.clone(), min.lt_eq(left_expr), right_expr)
1295-
}
1296-
BinaryOperator::LtEq => {
1297-
let max = array_max(right_expr.clone());
1298-
any_op_with_null_handling(max.clone(), max.gt_eq(left_expr), right_expr)
1299-
}
1300-
_ => plan_err!(
1301-
"Unsupported AnyOp: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported"
1302-
),
1303-
}
1304-
}
1305-
1306-
/// Plans `needle <compare_op> ALL(haystack)` with proper SQL NULL semantics.
1262+
/// Plans `needle <compare_op> ANY/ALL(haystack)` with proper SQL NULL semantics.
13071263
///
13081264
/// CASE/WHEN structure:
13091265
/// WHEN arr IS NULL → NULL
1310-
/// WHEN empty → TRUE
1266+
/// WHEN empty → vacuous_result (ANY:false, ALL:true)
13111267
/// WHEN lhs IS NULL → NULL
1312-
/// WHEN decisive_condition → FALSE
1268+
/// WHEN decisive_condition → decisive_result (ANY:true match found, ALL:false violation found)
13131269
/// WHEN has_nulls → NULL
1314-
/// ELSE → TRUE
1315-
fn plan_all_op(
1270+
/// ELSE → vacuous_result
1271+
fn plan_quantified_op(
13161272
needle: &Expr,
13171273
haystack: &Expr,
13181274
compare_op: &BinaryOperator,
1275+
quantifier: SetQuantifier,
13191276
) -> Result<Expr> {
13201277
let null_arr_check = haystack.clone().is_null();
13211278
let empty_check = cardinality(haystack.clone()).eq(lit(0u64));
@@ -1325,40 +1282,61 @@ fn plan_all_op(
13251282
let has_nulls =
13261283
array_position(haystack.clone(), lit(ScalarValue::Null), lit(1i64)).is_not_null();
13271284

1328-
let decisive_condition = match compare_op {
1329-
BinaryOperator::NotEq => array_has(haystack.clone(), needle.clone()),
1330-
BinaryOperator::Eq => {
1285+
let decisive_condition = match (compare_op, quantifier) {
1286+
(BinaryOperator::Eq, SetQuantifier::Any)
1287+
| (BinaryOperator::NotEq, SetQuantifier::All) => {
1288+
array_has(haystack.clone(), needle.clone())
1289+
}
1290+
(BinaryOperator::Eq, SetQuantifier::All)
1291+
| (BinaryOperator::NotEq, SetQuantifier::Any) => {
13311292
let all_equal = array_min(haystack.clone())
13321293
.eq(needle.clone())
13331294
.and(array_max(haystack.clone()).eq(needle.clone()));
13341295
Expr::Not(Box::new(all_equal))
13351296
}
1336-
BinaryOperator::Gt => {
1297+
(BinaryOperator::Gt, SetQuantifier::Any) => {
1298+
needle.clone().gt(array_min(haystack.clone()))
1299+
}
1300+
(BinaryOperator::Gt, SetQuantifier::All) => {
13371301
Expr::Not(Box::new(needle.clone().gt(array_max(haystack.clone()))))
13381302
}
1339-
BinaryOperator::Lt => {
1303+
(BinaryOperator::Lt, SetQuantifier::Any) => {
1304+
needle.clone().lt(array_max(haystack.clone()))
1305+
}
1306+
(BinaryOperator::Lt, SetQuantifier::All) => {
13401307
Expr::Not(Box::new(needle.clone().lt(array_min(haystack.clone()))))
13411308
}
1342-
BinaryOperator::GtEq => {
1309+
(BinaryOperator::GtEq, SetQuantifier::Any) => {
1310+
needle.clone().gt_eq(array_min(haystack.clone()))
1311+
}
1312+
(BinaryOperator::GtEq, SetQuantifier::All) => {
13431313
Expr::Not(Box::new(needle.clone().gt_eq(array_max(haystack.clone()))))
13441314
}
1345-
BinaryOperator::LtEq => {
1315+
(BinaryOperator::LtEq, SetQuantifier::Any) => {
1316+
needle.clone().lt_eq(array_max(haystack.clone()))
1317+
}
1318+
(BinaryOperator::LtEq, SetQuantifier::All) => {
13461319
Expr::Not(Box::new(needle.clone().lt_eq(array_min(haystack.clone()))))
13471320
}
13481321
_ => {
13491322
return plan_err!(
1350-
"Unsupported AllOp: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported"
1323+
"Unsupported {quantifier}Op: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported"
13511324
);
13521325
}
13531326
};
13541327

1328+
let (vacuous_result, decisive_result) = match quantifier {
1329+
SetQuantifier::Any => (false, true),
1330+
SetQuantifier::All => (true, false),
1331+
};
1332+
13551333
let null_bool = lit(ScalarValue::Boolean(None));
13561334
when(null_arr_check, null_bool.clone())
1357-
.when(empty_check, lit(true))
1335+
.when(empty_check, lit(vacuous_result))
13581336
.when(null_lhs_check, null_bool.clone())
1359-
.when(decisive_condition, lit(false))
1337+
.when(decisive_condition, lit(decisive_result))
13601338
.when(has_nulls, null_bool)
1361-
.otherwise(lit(true))
1339+
.otherwise(lit(vacuous_result))
13621340
}
13631341

13641342
#[cfg(test)]

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ fn roundtrip_statement_postgres_any_array_expr() -> Result<(), DataFusionError>
368368
sql: "select left from array where 1 = any(left);",
369369
parser_dialect: GenericDialect {},
370370
unparser_dialect: UnparserPostgreSqlDialect {},
371-
expected: @r#"SELECT "array"."left" FROM "array" WHERE 1 = ANY("array"."left")"#,
371+
expected: @r#"SELECT "array"."left" FROM "array" WHERE CASE WHEN "array"."left" IS NULL THEN NULL WHEN (cardinality("array"."left") = 0) THEN false WHEN 1 IS NULL THEN NULL WHEN 1 = ANY("array"."left") THEN true WHEN array_position("array"."left", NULL, 1) IS NOT NULL THEN NULL ELSE false END"#,
372372
);
373373
Ok(())
374374
}

datafusion/sqllogictest/test_files/array/array_has.slt

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -517,16 +517,18 @@ logical_plan
517517
03)----SubqueryAlias: test
518518
04)------SubqueryAlias: t
519519
05)--------Projection:
520-
06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
521-
07)------------TableScan: generate_series() projection=[value]
520+
06)----------Filter: __common_expr_3 IS NULL AND Boolean(NULL) OR __common_expr_3 IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) IS NOT DISTINCT FROM Boolean(true) AND __common_expr_3 IS NOT NULL
521+
07)------------Projection: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) AS __common_expr_3
522+
08)--------------TableScan: generate_series() projection=[value]
522523
physical_plan
523524
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
524525
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
525526
03)----CoalescePartitionsExec
526527
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
527-
05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[]
528-
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
529-
07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
528+
05)--------FilterExec: __common_expr_3@0 IS NULL AND NULL OR __common_expr_3@0 IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) IS NOT DISTINCT FROM true AND __common_expr_3@0 IS NOT NULL, projection=[]
529+
06)----------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8View)), 1, 32) as __common_expr_3]
530+
07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
531+
08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
530532

531533
query I
532534
with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
@@ -754,26 +756,26 @@ select 5 <= any(make_array());
754756
false
755757

756758
# Mixed NULL + non-NULL array where no non-NULL element satisfies the condition
757-
# These return false (NULLs are skipped by array_min/array_max)
759+
# These return NULL because NULLs leave the result indeterminate
758760
query B
759761
select 5 > any(make_array(6, NULL));
760762
----
761-
false
763+
NULL
762764

763765
query B
764766
select 5 < any(make_array(3, NULL));
765767
----
766-
false
768+
NULL
767769

768770
query B
769771
select 5 >= any(make_array(6, NULL));
770772
----
771-
false
773+
NULL
772774

773775
query B
774776
select 5 <= any(make_array(3, NULL));
775777
----
776-
false
778+
NULL
777779

778780
# Mixed NULL + non-NULL array where a non-NULL element satisfies the condition
779781
query B
@@ -804,33 +806,38 @@ true
804806
query B
805807
select 5 <> any(make_array(5, NULL));
806808
----
807-
false
809+
NULL
808810

809-
# All-NULL array: all operators should return false
811+
# All-NULL array: all operators should return NULL (unknown comparison)
810812
query B
811813
select 5 > any(make_array(NULL::INT, NULL::INT));
812814
----
813-
false
815+
NULL
814816

815817
query B
816818
select 5 < any(make_array(NULL::INT, NULL::INT));
817819
----
818-
false
820+
NULL
819821

820822
query B
821823
select 5 >= any(make_array(NULL::INT, NULL::INT));
822824
----
823-
false
825+
NULL
824826

825827
query B
826828
select 5 <= any(make_array(NULL::INT, NULL::INT));
827829
----
828-
false
830+
NULL
829831

830832
query B
831833
select 5 <> any(make_array(NULL::INT, NULL::INT));
832834
----
833-
false
835+
NULL
836+
837+
query B
838+
select 5 = any(make_array(NULL::INT, NULL::INT));
839+
----
840+
NULL
834841

835842
# NULL left operand: should return NULL for non-empty arrays
836843
query B
@@ -890,6 +897,35 @@ select 5 <> any(NULL::INT[]);
890897
----
891898
NULL
892899

900+
query B
901+
select 5 = any(NULL::INT[]);
902+
----
903+
NULL
904+
905+
# NULL = ANY with non-empty array
906+
query B
907+
select NULL = any(make_array(1, 2, 3));
908+
----
909+
NULL
910+
911+
# = ANY with no match, no NULLs
912+
query B
913+
select 5 = any(make_array(1, 2, 3));
914+
----
915+
false
916+
917+
# = ANY with mixed NULL (satisfying) returns TRUE
918+
query B
919+
select 5 = any(make_array(5, NULL));
920+
----
921+
true
922+
923+
# = ANY with mixed NULL (non-satisfying): NULLs leave result indeterminate
924+
query B
925+
select 5 = any(make_array(1, 2, NULL));
926+
----
927+
NULL
928+
893929
statement ok
894930
DROP TABLE any_op_test;
895931

0 commit comments

Comments
 (0)