@@ -27,13 +27,13 @@ use arrow::{
2727 datatypes:: { DataType , Schema } ,
2828 record_batch:: RecordBatch ,
2929} ;
30- use datafusion_common:: { internal_err, tree_node:: TreeNodeVisitor , HashSet , Result } ;
30+ use datafusion_common:: { HashSet , Result , internal_err, tree_node:: TreeNodeVisitor } ;
3131use datafusion_common:: {
3232 plan_err,
3333 tree_node:: { TreeNode , TreeNodeRecursion } ,
3434} ;
3535use datafusion_expr:: ColumnarValue ;
36- use hashbrown:: { hash_map:: EntryRef , HashMap } ;
36+ use hashbrown:: { HashMap , hash_map:: EntryRef } ;
3737
3838/// Represents a lambda with the given parameter names and body
3939#[ derive( Debug , Eq , Clone ) ]
@@ -100,6 +100,14 @@ impl LambdaExpr {
100100 & self . captured_columns
101101 }
102102
103+ /// Returns the names of the lambda variables used in this lambda's body that are bound neither by this lambda's own parameters nor by any lambda nested inside it.
104+ /// Example:
105+ ///
106+ /// `array_transform([[[1, 2, 3]]], a -> array_transform(a, b -> array_transform(b, c -> length(a) + length(b) + c)))`
107+ ///
108+ /// For the outermost lambda, this returns an empty hash set.
109+ /// For the middle one, `HashSet("a")`.
110+ /// And for the innermost, `HashSet("a", "b")`.
103111 pub ( crate ) fn captured_variables ( & self ) -> & HashSet < String > {
104112 & self . captured_variables
105113 }
@@ -192,7 +200,7 @@ impl<'n> TreeNodeVisitor<'n> for Captures<'n> {
192200 fn f_down ( & mut self , node : & ' n Self :: Node ) -> Result < TreeNodeRecursion > {
193201 if let Some ( lambda) = node. as_any ( ) . downcast_ref :: < LambdaExpr > ( ) {
194202 for param in & lambda. params {
195- * self . shadows . entry_ref ( param. as_str ( ) ) . or_default ( ) += 1 ;
203+ * self . shadows . entry ( param. as_str ( ) ) . or_default ( ) += 1 ;
196204 }
197205 } else if let Some ( lambda_variable) =
198206 node. as_any ( ) . downcast_ref :: < LambdaVariable > ( )
@@ -230,3 +238,101 @@ impl<'n> TreeNodeVisitor<'n> for Captures<'n> {
230238 Ok ( TreeNodeRecursion :: Continue )
231239 }
232240}
241+
242+ #[ cfg( test) ]
243+ mod tests { // exercises lambda construction, capture analysis, and direct evaluation
244+ use crate :: {
245+ LambdaFunctionExpr ,
246+ expressions:: { Column , LambdaExpr , NoOp , lambda:: lambda, lambda_variable} ,
247+ } ;
248+ use arrow:: {
249+ array:: RecordBatch ,
250+ datatypes:: { DataType , Field , FieldRef , Schema } ,
251+ } ;
252+ use datafusion_common:: { HashSet , Result } ;
253+ use datafusion_expr:: { ColumnarValue , LambdaUDF } ;
254+ use std:: sync:: Arc ;
255+
256+ #[ derive( Debug , Hash , Eq , PartialEq ) ]
257+ struct DummyLambdaUDF ; // placeholder UDF: only `name` returns a value, every other trait method below is `unimplemented!()`
258+
259+ impl LambdaUDF for DummyLambdaUDF {
260+ fn as_any ( & self ) -> & dyn std:: any:: Any {
261+ unimplemented ! ( )
262+ }
263+
264+ fn name ( & self ) -> & str {
265+ "dummy_udlf"
266+ }
267+
268+ fn signature ( & self ) -> & datafusion_expr:: LambdaSignature {
269+ unimplemented ! ( )
270+ }
271+
272+ fn lambdas_parameters (
273+ & self ,
274+ _args : & [ datafusion_expr:: ValueOrLambda < FieldRef , ( ) > ] ,
275+ ) -> Result < Vec < Option < Vec < Field > > > > {
276+ unimplemented ! ( )
277+ }
278+
279+ fn return_field_from_args (
280+ & self ,
281+ _args : datafusion_expr:: LambdaReturnFieldArgs ,
282+ ) -> Result < FieldRef > {
283+ unimplemented ! ( )
284+ }
285+
286+ fn invoke_with_args (
287+ & self ,
288+ _args : datafusion_expr:: LambdaFunctionArgs ,
289+ ) -> Result < ColumnarValue > {
290+ unimplemented ! ( )
291+ }
292+ }
293+
294+ #[ test]
295+ fn test_lambda_captures ( ) {
296+ let null_field = Arc :: new ( Field :: new ( "" , DataType :: Null , true ) ) ;
297+
298+ // Lambda under test: `var_b -> dummy_udlf(var_a, var_b, column@0, var_c -> var_c)`; its body mentions the variables `var_a`, `var_b`, `var_c` and the column at index 0.
299+ let inner = LambdaExpr :: try_new (
300+ vec ! [ String :: from( "var_b" ) ] ,
301+ Arc :: new ( LambdaFunctionExpr :: new (
302+ "dummy_udlf" ,
303+ Arc :: new ( DummyLambdaUDF ) ,
304+ vec ! [
305+ lambda_variable( "var_a" , Arc :: clone( & null_field) ) . unwrap( ) ,
306+ lambda_variable( "var_b" , Arc :: clone( & null_field) ) . unwrap( ) ,
307+ Arc :: new( Column :: new( "column" , 0 ) ) ,
308+ lambda(
309+ [ "var_c" ] ,
310+ lambda_variable( "var_c" , Arc :: clone( & null_field) ) . unwrap( ) ,
311+ )
312+ . unwrap( ) ,
313+ ] ,
314+ Arc :: clone ( & null_field) ,
315+ Arc :: new ( Default :: default ( ) ) ,
316+ ) ) ,
317+ )
318+ . unwrap ( ) ;
319+
320+ assert_eq ! ( inner. captured_columns( ) , & HashSet :: from( [ 0 ] ) ) ; // the body references only `column@0`
321+ assert_eq ! (
322+ inner. captured_variables( ) ,
323+ & HashSet :: from( [ String :: from( "var_a" ) ] )
324+ ) ; // `var_b` is the lambda's own parameter and `var_c` is bound by the nested lambda, leaving only `var_a`
325+ }
326+
327+ #[ test]
328+ fn test_lambda_evaluate ( ) {
329+ let lambda = lambda ( [ "a" ] , Arc :: new ( NoOp :: new ( ) ) ) . unwrap ( ) ;
330+ let batch = RecordBatch :: new_empty ( Arc :: new ( Schema :: empty ( ) ) ) ;
331+ assert ! ( lambda. evaluate( & batch) . is_err( ) ) ; // expects evaluating the lambda expression itself against a batch to fail
332+ }
333+
334+ #[ test]
335+ fn test_lambda_duplicate_name ( ) {
336+ assert ! ( lambda( [ "a" , "a" ] , Arc :: new( NoOp :: new( ) ) ) . is_err( ) ) ; // expects a repeated parameter name to be rejected at construction
337+ }
338+ }
0 commit comments