@@ -40,7 +40,9 @@ use std::{any::Any, sync::Arc};
4040
4141use crate :: expressions:: case:: literal_lookup_table:: LiteralLookupTable ;
4242use arrow:: compute:: kernels:: merge:: { MergeIndex , merge, merge_n} ;
43- use datafusion_common:: tree_node:: { Transformed , TreeNode , TreeNodeRecursion } ;
43+ use datafusion_common:: tree_node:: {
44+ ScopedTreeNode , Transformed , TreeNode , TreeNodeRecursion ,
45+ } ;
4446use datafusion_physical_expr_common:: datum:: compare_with_eq;
4547use datafusion_physical_expr_common:: utils:: scatter;
4648use itertools:: Itertools ;
@@ -130,7 +132,7 @@ impl CaseBody {
130132 // Determine the set of columns that are used in all the expressions of the case body.
131133 let mut used_column_indices = IndexSet :: < usize > :: new ( ) ;
132134 let mut collect_column_indices = |expr : & Arc < dyn PhysicalExpr > | {
133- expr. apply ( |expr| {
135+ expr. apply_in_scope ( |expr| {
134136 if let Some ( column) = expr. as_any ( ) . downcast_ref :: < Column > ( ) {
135137 used_column_indices. insert ( column. index ( ) ) ;
136138 }
@@ -161,7 +163,7 @@ impl CaseBody {
161163 // using the column index mapping.
162164 let project = |expr : & Arc < dyn PhysicalExpr > | -> Result < Arc < dyn PhysicalExpr > > {
163165 Arc :: clone ( expr)
164- . transform_down ( |e| {
166+ . transform_down_in_scope ( |e| {
165167 if let Some ( column) = e. as_any ( ) . downcast_ref :: < Column > ( ) {
166168 let original = column. index ( ) ;
167169 let projected = * column_index_map. get ( & original) . unwrap ( ) ;
@@ -1397,7 +1399,7 @@ fn replace_with_null(
13971399 input_schema : & Schema ,
13981400) -> Result < Arc < dyn PhysicalExpr > , DataFusionError > {
13991401 let with_null = Arc :: clone ( expr)
1400- . transform_down ( |e| {
1402+ . transform_down_in_scope ( |e| {
14011403 if e. as_ref ( ) . dyn_eq ( expr_to_replace) {
14021404 let data_type = e. data_type ( input_schema) ?;
14031405 let null_literal = lit ( ScalarValue :: try_new_null ( & data_type) ?) ;
@@ -1928,135 +1930,6 @@ mod tests {
19281930 Ok ( ( ) )
19291931 }
19301932
1931- #[ test]
1932- fn case_without_expr_and_with_custom_column_impl ( ) -> Result < ( ) > {
1933- /// Represents the column at a given index in a RecordBatch that is inside a Spark lambda function
1934- ///
1935- /// This is the same as the datafusion [`datafusion::physical_expr::expressions::Column`] except that it store the entire info so that it can be used in lambda execution
1936- #[ derive( Debug , Hash , PartialEq , Eq , Clone ) ]
1937- pub struct CustomColumn {
1938- /// The name of the column (used for debugging and display purposes)
1939- name : String ,
1940- /// The index of the column in its schema
1941- index : usize ,
1942- data_type : DataType ,
1943- nullable : bool ,
1944- }
1945-
1946- impl CustomColumn {
1947- pub fn new_with_schema (
1948- name : & str ,
1949- schema : & Schema ,
1950- ) -> Result < Arc < dyn PhysicalExpr > > {
1951- let index = schema. index_of ( name) ?;
1952- let field = schema. field ( index) ;
1953- Ok ( Arc :: new ( CustomColumn {
1954- name : name. to_string ( ) ,
1955- index,
1956- data_type : field. data_type ( ) . clone ( ) ,
1957- nullable : field. is_nullable ( ) ,
1958- } ) )
1959- }
1960- }
1961-
1962- impl std:: fmt:: Display for CustomColumn {
1963- fn fmt ( & self , f : & mut Formatter ) -> std:: fmt:: Result {
1964- write ! ( f, "{}@{}" , self . name, self . index)
1965- }
1966- }
1967-
1968- impl PhysicalExpr for CustomColumn {
1969- fn as_any ( & self ) -> & dyn Any {
1970- self
1971- }
1972-
1973- fn data_type ( & self , _input_schema : & Schema ) -> Result < DataType > {
1974- Ok ( self . data_type . clone ( ) )
1975- }
1976-
1977- fn nullable ( & self , _input_schema : & Schema ) -> Result < bool > {
1978- Ok ( self . nullable )
1979- }
1980-
1981- fn evaluate ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
1982- self . bounds_check ( batch. schema ( ) . as_ref ( ) ) ?;
1983- Ok ( ColumnarValue :: Array ( Arc :: clone ( batch. column ( self . index ) ) ) )
1984- }
1985-
1986- fn children ( & self ) -> Vec < & Arc < dyn PhysicalExpr > > {
1987- vec ! [ ]
1988- }
1989-
1990- fn with_new_children (
1991- self : Arc < Self > ,
1992- _children : Vec < Arc < dyn PhysicalExpr > > ,
1993- ) -> Result < Arc < dyn PhysicalExpr > > {
1994- Ok ( self )
1995- }
1996-
1997- fn fmt_sql ( & self , _: & mut Formatter < ' _ > ) -> std:: fmt:: Result {
1998- unimplemented ! ( )
1999- }
2000- }
2001-
2002- impl CustomColumn {
2003- fn bounds_check ( & self , input_schema : & Schema ) -> Result < ( ) > {
2004- if self . index < input_schema. fields . len ( ) {
2005- Ok ( ( ) )
2006- } else {
2007- internal_err ! (
2008- "PhysicalExpr BoundLambdaColumn references column '{}' at index {} (zero-based) but input schema only has {} columns: {:?}" ,
2009- self . name,
2010- self . index,
2011- input_schema. fields. len( ) ,
2012- input_schema
2013- . fields( )
2014- . iter( )
2015- . map( |f| f. name( ) )
2016- . collect:: <Vec <_>>( )
2017- )
2018- }
2019- }
2020- }
2021-
2022- let batch = case_test_batch ( ) ?;
2023- let schema = batch. schema ( ) ;
2024-
2025- // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
2026- let when1 = binary (
2027- CustomColumn :: new_with_schema ( "a" , & schema) ?,
2028- Operator :: Eq ,
2029- lit ( "foo" ) ,
2030- & batch. schema ( ) ,
2031- ) ?;
2032- let then1 = lit ( 123i32 ) ;
2033- let when2 = binary (
2034- CustomColumn :: new_with_schema ( "a" , & schema) ?,
2035- Operator :: Eq ,
2036- lit ( "bar" ) ,
2037- & batch. schema ( ) ,
2038- ) ?;
2039- let then2 = lit ( 456i32 ) ;
2040-
2041- let expr = generate_case_when_with_type_coercion (
2042- None ,
2043- vec ! [ ( when1, then1) , ( when2, then2) ] ,
2044- None ,
2045- schema. as_ref ( ) ,
2046- ) ?;
2047- let result = expr
2048- . evaluate ( & batch) ?
2049- . into_array ( batch. num_rows ( ) )
2050- . expect ( "Failed to convert to array" ) ;
2051- let result = as_int32_array ( & result) ?;
2052-
2053- let expected = & Int32Array :: from ( vec ! [ Some ( 123 ) , None , None , Some ( 456 ) ] ) ;
2054-
2055- assert_eq ! ( expected, result) ;
2056-
2057- Ok ( ( ) )
2058- }
2059-
20601933 #[ test]
20611934 fn case_with_expr_when_null ( ) -> Result < ( ) > {
20621935 let batch = case_test_batch ( ) ?;
@@ -2552,7 +2425,7 @@ mod tests {
25522425 . unwrap ( ) ;
25532426
25542427 let expr3 = Arc :: clone ( & expr)
2555- . transform_down ( |e| {
2428+ . transform_down_in_scope ( |e| {
25562429 let transformed = match e. as_any ( ) . downcast_ref :: < Literal > ( ) {
25572430 Some ( lit_value) => match lit_value. value ( ) {
25582431 ScalarValue :: Utf8 ( Some ( str_value) ) => {
0 commit comments