@@ -42,6 +42,7 @@ use crate::expressions::case::literal_lookup_table::LiteralLookupTable;
4242use arrow:: compute:: kernels:: merge:: { MergeIndex , merge, merge_n} ;
4343use datafusion_common:: tree_node:: { Transformed , TreeNode , TreeNodeRecursion } ;
4444use datafusion_physical_expr_common:: datum:: compare_with_eq;
45+ use datafusion_physical_expr_common:: utils:: scatter;
4546use itertools:: Itertools ;
4647use std:: fmt:: { Debug , Formatter } ;
4748
@@ -64,17 +65,21 @@ enum EvalMethod {
6465 /// for expressions that are infallible and can be cheaply computed for the entire
6566 /// record batch rather than just for the rows where the predicate is true.
6667 ///
67- /// CASE WHEN condition THEN column [ELSE NULL] END
68+ /// CASE WHEN condition THEN infallible_expression [ELSE NULL] END
6869 InfallibleExprOrNull ,
6970 /// This is a specialization for a specific use case where we can take a fast path
7071 /// if there is just one when/then pair and both the `then` and `else` expressions
7172 /// are literal values
7273 /// CASE WHEN condition THEN literal ELSE literal END
7374 ScalarOrScalar ,
7475 /// This is a specialization for a specific use case where we can take a fast path
75- /// if there is just one when/then pair and both the `then` and `else` are expressions
76+ /// if there is just one when/then pair, the `then` is an expression, and `else` is either
77+ /// an expression, literal NULL or absent.
7678 ///
77- /// CASE WHEN condition THEN expression ELSE expression END
79+ /// In contrast to [`EvalMethod::InfallibleExprOrNull`], this specialization can handle fallible
80+ /// `then` expressions.
81+ ///
82+ /// CASE WHEN condition THEN expression [ELSE expression] END
7883 ExpressionOrExpression ( ProjectedCaseBody ) ,
7984
8085 /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals
@@ -659,7 +664,7 @@ impl CaseExpr {
659664 && body. else_expr . as_ref ( ) . unwrap ( ) . as_any ( ) . is :: < Literal > ( )
660665 {
661666 EvalMethod :: ScalarOrScalar
662- } else if body. when_then_expr . len ( ) == 1 && body . else_expr . is_some ( ) {
667+ } else if body. when_then_expr . len ( ) == 1 {
663668 EvalMethod :: ExpressionOrExpression ( body. project ( ) ?)
664669 } else {
665670 EvalMethod :: NoExpression ( body. project ( ) ?)
@@ -961,32 +966,40 @@ impl CaseBody {
961966 let then_batch = filter_record_batch ( batch, & when_filter) ?;
962967 let then_value = self . when_then_expr [ 0 ] . 1 . evaluate ( & then_batch) ?;
963968
964- let else_selection = not ( & when_value) ?;
965- let else_filter = create_filter ( & else_selection, optimize_filter) ;
966- let else_batch = filter_record_batch ( batch, & else_filter) ?;
967-
968- // keep `else_expr`'s data type and return type consistent
969- let e = self . else_expr . as_ref ( ) . unwrap ( ) ;
970- let return_type = self . data_type ( & batch. schema ( ) ) ?;
971- let else_expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) )
972- . unwrap_or_else ( |_| Arc :: clone ( e) ) ;
973-
974- let else_value = else_expr. evaluate ( & else_batch) ?;
975-
976- Ok ( ColumnarValue :: Array ( match ( then_value, else_value) {
977- ( ColumnarValue :: Array ( t) , ColumnarValue :: Array ( e) ) => {
978- merge ( & when_value, & t, & e)
979- }
980- ( ColumnarValue :: Scalar ( t) , ColumnarValue :: Array ( e) ) => {
981- merge ( & when_value, & t. to_scalar ( ) ?, & e)
982- }
983- ( ColumnarValue :: Array ( t) , ColumnarValue :: Scalar ( e) ) => {
984- merge ( & when_value, & t, & e. to_scalar ( ) ?)
969+ match & self . else_expr {
970+ None => {
971+ let then_array = then_value. to_array ( when_value. true_count ( ) ) ?;
972+ scatter ( & when_value, then_array. as_ref ( ) ) . map ( ColumnarValue :: Array )
985973 }
986- ( ColumnarValue :: Scalar ( t) , ColumnarValue :: Scalar ( e) ) => {
987- merge ( & when_value, & t. to_scalar ( ) ?, & e. to_scalar ( ) ?)
974+ Some ( else_expr) => {
975+ let else_selection = not ( & when_value) ?;
976+ let else_filter = create_filter ( & else_selection, optimize_filter) ;
977+ let else_batch = filter_record_batch ( batch, & else_filter) ?;
978+
979+ // keep `else_expr`'s data type and return type consistent
980+ let return_type = self . data_type ( & batch. schema ( ) ) ?;
981+ let else_expr =
982+ try_cast ( Arc :: clone ( else_expr) , & batch. schema ( ) , return_type. clone ( ) )
983+ . unwrap_or_else ( |_| Arc :: clone ( else_expr) ) ;
984+
985+ let else_value = else_expr. evaluate ( & else_batch) ?;
986+
987+ Ok ( ColumnarValue :: Array ( match ( then_value, else_value) {
988+ ( ColumnarValue :: Array ( t) , ColumnarValue :: Array ( e) ) => {
989+ merge ( & when_value, & t, & e)
990+ }
991+ ( ColumnarValue :: Scalar ( t) , ColumnarValue :: Array ( e) ) => {
992+ merge ( & when_value, & t. to_scalar ( ) ?, & e)
993+ }
994+ ( ColumnarValue :: Array ( t) , ColumnarValue :: Scalar ( e) ) => {
995+ merge ( & when_value, & t, & e. to_scalar ( ) ?)
996+ }
997+ ( ColumnarValue :: Scalar ( t) , ColumnarValue :: Scalar ( e) ) => {
998+ merge ( & when_value, & t. to_scalar ( ) ?, & e. to_scalar ( ) ?)
999+ }
1000+ } ?) )
9881001 }
989- } ? ) )
1002+ }
9901003 }
9911004}
9921005
@@ -1137,7 +1150,15 @@ impl CaseExpr {
11371150 self . body . when_then_expr [ 0 ] . 1 . evaluate ( batch)
11381151 } else if true_count == 0 {
11391152 // All input rows are false/null: evaluate the 'else' expression, or yield a typed NULL scalar when there is none
1140- self . body . else_expr . as_ref ( ) . unwrap ( ) . evaluate ( batch)
1153+ match & self . body . else_expr {
1154+ Some ( else_expr) => else_expr. evaluate ( batch) ,
1155+ None => {
1156+ let return_type = self . data_type ( & batch. schema ( ) ) ?;
1157+ Ok ( ColumnarValue :: Scalar ( ScalarValue :: try_new_null (
1158+ & return_type,
1159+ ) ?) )
1160+ }
1161+ }
11411162 } else if projected. projection . len ( ) < batch. num_columns ( ) {
11421163 // The case expressions do not use all the columns of the input batch.
11431164 // Project first to reduce time spent filtering.
0 commit comments