Skip to content

Commit 4dfc193

Browse files
authored
Improve performance of CASE WHEN x THEN y ELSE NULL expressions (#20097)
## Which issue does this PR close? - Related to #11570 ## Rationale for this change While reviewing #19994 it became clear the optimised `ExpressionOrExpression` code path was not being used when the case expression has no `else` branch or has `else null`. In those situations the general evaluation strategies could end up being used. This PR refines the `ExpressionOrExpression` implementation to also handle `else null` expressions. ## What changes are included in this PR? Use `ExpressionOrExpression` for expressions of the form `CASE WHEN x THEN y [ELSE NULL]` ## Are these changes tested? Covered by existing SLTs ## Are there any user-facing changes? No
1 parent 35e78ca commit 4dfc193

2 files changed

Lines changed: 82 additions & 29 deletions

File tree

  • datafusion

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ use crate::expressions::case::literal_lookup_table::LiteralLookupTable;
4242
use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n};
4343
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
4444
use datafusion_physical_expr_common::datum::compare_with_eq;
45+
use datafusion_physical_expr_common::utils::scatter;
4546
use itertools::Itertools;
4647
use std::fmt::{Debug, Formatter};
4748

@@ -64,17 +65,21 @@ enum EvalMethod {
6465
/// for expressions that are infallible and can be cheaply computed for the entire
6566
/// record batch rather than just for the rows where the predicate is true.
6667
///
67-
/// CASE WHEN condition THEN column [ELSE NULL] END
68+
/// CASE WHEN condition THEN infallible_expression [ELSE NULL] END
6869
InfallibleExprOrNull,
6970
/// This is a specialization for a specific use case where we can take a fast path
7071
/// if there is just one when/then pair and both the `then` and `else` expressions
7172
/// are literal values
7273
/// CASE WHEN condition THEN literal ELSE literal END
7374
ScalarOrScalar,
7475
/// This is a specialization for a specific use case where we can take a fast path
75-
/// if there is just one when/then pair and both the `then` and `else` are expressions
76+
/// if there is just one when/then pair, the `then` is an expression, and `else` is either
77+
/// an expression, literal NULL or absent.
7678
///
77-
/// CASE WHEN condition THEN expression ELSE expression END
79+
/// In contrast to [`EvalMethod::InfallibleExprOrNull`], this specialization can handle fallible
80+
/// `then` expressions.
81+
///
82+
/// CASE WHEN condition THEN expression [ELSE expression] END
7883
ExpressionOrExpression(ProjectedCaseBody),
7984

8085
/// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals
@@ -659,7 +664,7 @@ impl CaseExpr {
659664
&& body.else_expr.as_ref().unwrap().as_any().is::<Literal>()
660665
{
661666
EvalMethod::ScalarOrScalar
662-
} else if body.when_then_expr.len() == 1 && body.else_expr.is_some() {
667+
} else if body.when_then_expr.len() == 1 {
663668
EvalMethod::ExpressionOrExpression(body.project()?)
664669
} else {
665670
EvalMethod::NoExpression(body.project()?)
@@ -961,32 +966,40 @@ impl CaseBody {
961966
let then_batch = filter_record_batch(batch, &when_filter)?;
962967
let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
963968

964-
let else_selection = not(&when_value)?;
965-
let else_filter = create_filter(&else_selection, optimize_filter);
966-
let else_batch = filter_record_batch(batch, &else_filter)?;
967-
968-
// keep `else_expr`'s data type and return type consistent
969-
let e = self.else_expr.as_ref().unwrap();
970-
let return_type = self.data_type(&batch.schema())?;
971-
let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
972-
.unwrap_or_else(|_| Arc::clone(e));
973-
974-
let else_value = else_expr.evaluate(&else_batch)?;
975-
976-
Ok(ColumnarValue::Array(match (then_value, else_value) {
977-
(ColumnarValue::Array(t), ColumnarValue::Array(e)) => {
978-
merge(&when_value, &t, &e)
979-
}
980-
(ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => {
981-
merge(&when_value, &t.to_scalar()?, &e)
982-
}
983-
(ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => {
984-
merge(&when_value, &t, &e.to_scalar()?)
969+
match &self.else_expr {
970+
None => {
971+
let then_array = then_value.to_array(when_value.true_count())?;
972+
scatter(&when_value, then_array.as_ref()).map(ColumnarValue::Array)
985973
}
986-
(ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => {
987-
merge(&when_value, &t.to_scalar()?, &e.to_scalar()?)
974+
Some(else_expr) => {
975+
let else_selection = not(&when_value)?;
976+
let else_filter = create_filter(&else_selection, optimize_filter);
977+
let else_batch = filter_record_batch(batch, &else_filter)?;
978+
979+
// keep `else_expr`'s data type and return type consistent
980+
let return_type = self.data_type(&batch.schema())?;
981+
let else_expr =
982+
try_cast(Arc::clone(else_expr), &batch.schema(), return_type.clone())
983+
.unwrap_or_else(|_| Arc::clone(else_expr));
984+
985+
let else_value = else_expr.evaluate(&else_batch)?;
986+
987+
Ok(ColumnarValue::Array(match (then_value, else_value) {
988+
(ColumnarValue::Array(t), ColumnarValue::Array(e)) => {
989+
merge(&when_value, &t, &e)
990+
}
991+
(ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => {
992+
merge(&when_value, &t.to_scalar()?, &e)
993+
}
994+
(ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => {
995+
merge(&when_value, &t, &e.to_scalar()?)
996+
}
997+
(ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => {
998+
merge(&when_value, &t.to_scalar()?, &e.to_scalar()?)
999+
}
1000+
}?))
9881001
}
989-
}?))
1002+
}
9901003
}
9911004
}
9921005

@@ -1137,7 +1150,15 @@ impl CaseExpr {
11371150
self.body.when_then_expr[0].1.evaluate(batch)
11381151
} else if true_count == 0 {
11391152
// All input rows are false/null, just call the 'else' expression
1140-
self.body.else_expr.as_ref().unwrap().evaluate(batch)
1153+
match &self.body.else_expr {
1154+
Some(else_expr) => else_expr.evaluate(batch),
1155+
None => {
1156+
let return_type = self.data_type(&batch.schema())?;
1157+
Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
1158+
&return_type,
1159+
)?))
1160+
}
1161+
}
11411162
} else if projected.projection.len() < batch.num_columns() {
11421163
// The case expressions do not use all the columns of the input batch.
11431164
// Project first to reduce time spent filtering.

datafusion/sqllogictest/test_files/case.slt

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,38 @@ NULL
642642
NULL
643643
-1
644644

645+
# single WHEN, no ELSE (absent)
646+
query I
647+
SELECT CASE WHEN a > 0 THEN b END
648+
FROM (VALUES (1, 10), (0, 20)) AS t(a, b);
649+
----
650+
10
651+
NULL
652+
653+
# single WHEN, explicit ELSE NULL
654+
query I
655+
SELECT CASE WHEN a > 0 THEN b ELSE NULL END
656+
FROM (VALUES (1, 10), (0, 20)) AS t(a, b);
657+
----
658+
10
659+
NULL
660+
661+
# fallible THEN expression should only be evaluated on true rows
662+
query I
663+
SELECT CASE WHEN a > 0 THEN 10 / a END
664+
FROM (VALUES (1), (0)) AS t(a);
665+
----
666+
10
667+
NULL
668+
669+
# all-false path returns typed NULLs
670+
query I
671+
SELECT CASE WHEN a < 0 THEN b END
672+
FROM (VALUES (1, 10), (2, 20)) AS t(a, b);
673+
----
674+
NULL
675+
NULL
676+
645677
# EvalMethod::WithExpression using subset of all selected columns in case expression
646678
query III
647679
SELECT CASE a1 WHEN 1 THEN a1 WHEN 2 THEN a2 WHEN 3 THEN b END, b, c

0 commit comments

Comments
 (0)