@@ -23,21 +23,24 @@ use super::{
2323 from_substrait_rex, from_window_function,
2424} ;
2525use crate :: extensions:: Extensions ;
26+ use crate :: logical_plan:: consumer:: from_substrait_type_without_names;
2627use async_trait:: async_trait;
27- use datafusion:: arrow:: datatypes:: DataType ;
28+ use datafusion:: arrow:: datatypes:: { DataType , Field , FieldRef } ;
2829use datafusion:: catalog:: TableProvider ;
2930use datafusion:: common:: {
3031 DFSchema , ScalarValue , TableReference , not_impl_err, substrait_err,
3132} ;
3233use datafusion:: execution:: { FunctionRegistry , SessionState } ;
3334use datafusion:: logical_expr:: { Expr , Extension , LogicalPlan } ;
35+ use datafusion:: prelude:: { lambda, lambda_var} ;
36+ use std:: collections:: VecDeque ;
3437use std:: sync:: { Arc , RwLock } ;
35- use substrait:: proto;
3638use substrait:: proto:: expression as substrait_expression;
3739use substrait:: proto:: expression:: {
3840 Enum , FieldReference , IfThen , Literal , MultiOrList , Nested , ScalarFunction ,
3941 SingularOrList , SwitchExpression , WindowFunction ,
4042} ;
43+ use substrait:: proto:: { self , Type } ;
4144use substrait:: proto:: {
4245 AggregateRel , ConsistentPartitionWindowRel , CrossRel , DynamicParameter , ExchangeRel ,
4346 Expression , ExtensionLeafRel , ExtensionMultiRel , ExtensionSingleRel , FetchRel ,
@@ -372,6 +375,14 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
372375 not_impl_err ! ( "Dynamic Parameter expression not supported" )
373376 }
374377
378+ async fn consume_lambda (
379+ & self ,
380+ expr : & proto:: expression:: Lambda ,
381+ input_schema : & DFSchema ,
382+ ) -> datafusion:: common:: Result < Expr > {
383+ from_lambda ( self , expr, input_schema) . await
384+ }
385+
375386 // Outer Schema Stack
376387 // These methods manage a stack of outer schemas for correlated subquery support.
377388 // When entering a subquery, the enclosing query's schema is pushed onto the stack.
@@ -469,6 +480,32 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
469480 } ;
470481 substrait_err ! ( "Missing handler for user-defined literals {}" , type_ref)
471482 }
483+
484+ fn with_lambda_parameters (
485+ & self ,
486+ lambda_parameters : & [ Type ] ,
487+ ) -> datafusion:: common:: Result < ( Vec < String > , Self ) > ;
488+
489+ fn lambda_variable (
490+ & self ,
491+ steps_out : usize ,
492+ field_idx : usize ,
493+ ) -> datafusion:: common:: Result < Expr > ;
494+ }
495+
496+ async fn from_lambda (
497+ consumer : & impl SubstraitConsumer ,
498+ expr : & proto:: expression:: Lambda ,
499+ input_schema : & DFSchema ,
500+ ) -> datafusion:: common:: Result < Expr > {
501+ let parameters = expr. parameters . as_ref ( ) . unwrap ( ) ;
502+
503+ let ( names, consumer) = consumer. with_lambda_parameters ( & parameters. types ) ?;
504+
505+ let body = expr. body . as_ref ( ) . unwrap ( ) ;
506+ let body = consumer. consume_expression ( body, input_schema) . await ?;
507+
508+ Ok ( lambda ( names, body) )
472509}
473510
474511/// Default SubstraitConsumer for converting standard Substrait without user-defined extensions.
@@ -478,6 +515,8 @@ pub struct DefaultSubstraitConsumer<'a> {
478515 pub ( super ) extensions : & ' a Extensions ,
479516 pub ( super ) state : & ' a SessionState ,
480517 outer_schemas : RwLock < Vec < Arc < DFSchema > > > ,
518+ lambdas_parameters : VecDeque < Vec < FieldRef > > ,
519+ num_lambda_parameters : usize ,
481520}
482521
483522impl < ' a > DefaultSubstraitConsumer < ' a > {
@@ -486,12 +525,61 @@ impl<'a> DefaultSubstraitConsumer<'a> {
486525 extensions,
487526 state,
488527 outer_schemas : RwLock :: new ( Vec :: new ( ) ) ,
528+ lambdas_parameters : VecDeque :: new ( ) ,
529+ num_lambda_parameters : 0 ,
489530 }
490531 }
491532}
492533
493534#[ async_trait]
494535impl SubstraitConsumer for DefaultSubstraitConsumer < ' _ > {
536+ fn with_lambda_parameters (
537+ & self ,
538+ lambda_parameters : & [ Type ] ,
539+ ) -> datafusion:: common:: Result < ( Vec < String > , Self ) > {
540+ let mut lambdas_parameters = self . lambdas_parameters . clone ( ) ;
541+
542+ let lambda_parameters = lambda_parameters
543+ . iter ( )
544+ . enumerate ( )
545+ . map ( |( i, ty) | {
546+ let dt = from_substrait_type_without_names ( self , ty) ?;
547+
548+ Ok ( Arc :: new ( Field :: new (
549+ format ! ( "p{}" , i + self . num_lambda_parameters) ,
550+ dt,
551+ true ,
552+ ) ) )
553+ } )
554+ . collect :: < datafusion:: common:: Result < Vec < _ > > > ( ) ?;
555+
556+ let names = lambda_parameters. iter ( ) . map ( |f| f. name ( ) . clone ( ) ) . collect ( ) ;
557+ let num_lambda_parameters = self . num_lambda_parameters + lambda_parameters. len ( ) ;
558+
559+ lambdas_parameters. push_front ( lambda_parameters) ;
560+
561+ Ok ( (
562+ names,
563+ Self {
564+ extensions : self . extensions ,
565+ state : self . state ,
566+ outer_schemas : RwLock :: new ( self . outer_schemas . read ( ) . unwrap ( ) . clone ( ) ) ,
567+ lambdas_parameters,
568+ num_lambda_parameters,
569+ } ,
570+ ) )
571+ }
572+
573+ fn lambda_variable (
574+ & self ,
575+ steps_out : usize ,
576+ field_idx : usize ,
577+ ) -> datafusion:: common:: Result < Expr > {
578+ let var = & self . lambdas_parameters [ steps_out] [ field_idx] ;
579+
580+ Ok ( lambda_var ( var. name ( ) , Arc :: clone ( var) ) )
581+ }
582+
495583 async fn resolve_table_ref (
496584 & self ,
497585 table_ref : & TableReference ,
0 commit comments