@@ -29,9 +29,7 @@ use crate::logical_plan::producer::{
2929 to_substrait_type_from_field,
3030} ;
3131use datafusion:: arrow:: datatypes:: Field ;
32- use datafusion:: common:: {
33- Column , DFSchemaRef , HashMap , ScalarValue , substrait_datafusion_err, substrait_err,
34- } ;
32+ use datafusion:: common:: { Column , DFSchemaRef , HashMap , ScalarValue , substrait_err} ;
3533use datafusion:: execution:: SessionState ;
3634use datafusion:: execution:: registry:: SerializerRegistry ;
3735use datafusion:: logical_expr:: Subquery ;
@@ -52,7 +50,7 @@ use substrait::proto::expression::field_reference::{
5250use substrait:: proto:: expression:: reference_segment:: { self , StructField } ;
5351use substrait:: proto:: expression:: { FieldReference , ReferenceSegment , RexType } ;
5452use substrait:: proto:: rel:: RelType ;
55- use substrait:: proto:: r#type:: Struct ;
53+ use substrait:: proto:: r#type:: { Nullability , Struct } ;
5654use substrait:: proto:: {
5755 Expression , ExtensionLeafRel , ExtensionMultiRel , ExtensionSingleRel , Rel ,
5856} ;
@@ -426,12 +424,14 @@ pub trait SubstraitProducer: Send + Sync + Sized {
426424 from_lambda_variable ( self , lambda_variable, schema)
427425 }
428426
429- fn with_lambda_parameters (
427+ fn push_lambda_parameters (
430428 & mut self ,
431429 lambda_parameters : Vec < Field > ,
432- ) -> datafusion:: common:: Result < Self > ;
430+ ) -> datafusion:: common:: Result < ( ) > ;
431+
432+ fn pop_lambda_parameters ( & mut self ) -> datafusion:: common:: Result < ( ) > ;
433433
434- fn lambda_variable ( & self , name : & str ) -> datafusion:: common:: Result < ( i32 , u32 ) > ;
434+ fn lambda_variable ( & self , name : & str ) -> datafusion:: common:: Result < ( u32 , i32 ) > ;
435435
436436 fn lambda_parameter_type (
437437 & self ,
@@ -444,7 +444,7 @@ fn from_lambda_variable(
444444 lambda_variable : & LambdaVariable ,
445445 _schema : & datafusion:: common:: DFSchema ,
446446) -> Result < Expression , datafusion:: error:: DataFusionError > {
447- let ( field , steps_out ) = producer. lambda_variable ( & lambda_variable. name ) ?;
447+ let ( steps_out , field ) = producer. lambda_variable ( & lambda_variable. name ) ?;
448448
449449 Ok ( Expression {
450450 rex_type : Some ( RexType :: Selection ( Box :: new ( FieldReference {
@@ -469,7 +469,7 @@ fn from_lambda(
469469 rex_type : Some ( RexType :: Lambda ( Box :: new (
470470 substrait:: proto:: expression:: Lambda {
471471 parameters : Some ( Struct {
472- nullability : 1 ,
472+ nullability : Nullability :: Required as i32 ,
473473 type_variation_reference : 0 ,
474474 types : lambda
475475 . params
@@ -486,15 +486,15 @@ fn from_lambda(
486486pub struct DefaultSubstraitProducer < ' a > {
487487 extensions : Extensions ,
488488 serializer_registry : & ' a dyn SerializerRegistry ,
489- lambdas_variables : HashMap < String , ( i32 , u32 , substrait:: proto:: Type ) > ,
489+ lambdas_variables : Vec < HashMap < String , ( usize , substrait:: proto:: Type ) > > ,
490490}
491491
492492impl < ' a > DefaultSubstraitProducer < ' a > {
493493 pub fn new ( state : & ' a SessionState ) -> Self {
494494 DefaultSubstraitProducer {
495495 extensions : Extensions :: default ( ) ,
496496 serializer_registry : state. serializer_registry ( ) . as_ref ( ) ,
497- lambdas_variables : HashMap :: new ( ) ,
497+ lambdas_variables : Vec :: new ( ) ,
498498 }
499499 }
500500}
@@ -550,50 +550,58 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> {
550550 } ) )
551551 }
552552
553- fn with_lambda_parameters (
553+ fn push_lambda_parameters (
554554 & mut self ,
555555 lambda_parameters : Vec < Field > ,
556- ) -> datafusion:: common:: Result < Self > {
557- let mut lambdas_variables = self . lambdas_variables . clone ( ) ;
558-
559- for ( _field_idx, hops_out, _type) in lambdas_variables. values_mut ( ) {
560- * hops_out += 1 ;
556+ ) -> datafusion:: common:: Result < ( ) > {
557+ let vars = lambda_parameters
558+ . into_iter ( )
559+ . enumerate ( )
560+ . map ( |( field_idx, field) | {
561+ Ok ( (
562+ field. name ( ) . clone ( ) ,
563+ (
564+ field_idx,
565+ to_substrait_type_from_field ( self , & Arc :: new ( field) ) ?,
566+ ) ,
567+ ) )
568+ } )
569+ . collect :: < datafusion:: common:: Result < _ > > ( ) ?;
570+
571+ self . lambdas_variables . push ( vars) ;
572+
573+ Ok ( ( ) )
574+ }
575+
576+ fn pop_lambda_parameters ( & mut self ) -> datafusion:: common:: Result < ( ) > {
577+ match self . lambdas_variables . pop ( ) {
578+ Some ( _) => Ok ( ( ) ) ,
579+ None => substrait_err ! ( "no lambda_parameters to pop" ) ,
561580 }
581+ }
562582
563- for ( field_idx, field) in lambda_parameters. iter ( ) . enumerate ( ) {
564- let hops_out = 0 ;
565-
566- lambdas_variables. insert (
567- field. name ( ) . clone ( ) ,
568- (
569- field_idx as i32 ,
570- hops_out,
571- to_substrait_type_from_field ( self , & Arc :: new ( field. clone ( ) ) ) ?,
572- ) ,
573- ) ;
583+ fn lambda_variable ( & self , name : & str ) -> datafusion:: common:: Result < ( u32 , i32 ) > {
584+ for ( steps_out, lambda_parameters) in
585+ self . lambdas_variables . iter ( ) . rev ( ) . enumerate ( )
586+ {
587+ if let Some ( ( field_idx, _type) ) = lambda_parameters. get ( name) {
588+ return Ok ( ( steps_out as u32 , * field_idx as i32 ) ) ;
589+ }
574590 }
575591
576- Ok ( Self {
577- extensions : self . extensions . clone ( ) ,
578- serializer_registry : self . serializer_registry ,
579- lambdas_variables,
580- } )
581- }
582-
583- fn lambda_variable ( & self , name : & str ) -> datafusion:: common:: Result < ( i32 , u32 ) > {
584- self . lambdas_variables
585- . get ( name)
586- . map ( |( field, steps_out, _type) | ( * field, * steps_out) )
587- . ok_or_else ( || substrait_datafusion_err ! ( "unknow lambda variable {name}" ) )
592+ substrait_err ! ( "unknow lambda variable {name}" )
588593 }
589594
590595 fn lambda_parameter_type (
591596 & self ,
592597 name : & str ,
593598 ) -> datafusion:: common:: Result < substrait:: proto:: Type > {
594- self . lambdas_variables
595- . get ( name)
596- . map ( |( _field, _steps_out, type_) | type_. clone ( ) )
597- . ok_or_else ( || substrait_datafusion_err ! ( "unknow lambda variable {name}" ) )
599+ for lambda_parameters in self . lambdas_variables . iter ( ) . rev ( ) {
600+ if let Some ( ( _field_idx, type_) ) = lambda_parameters. get ( name) {
601+ return Ok ( type_. clone ( ) ) ;
602+ }
603+ }
604+
605+ substrait_err ! ( "unknow lambda variable {name}" )
598606 }
599607}
0 commit comments