Skip to content

Commit a1abb74

Browse files
committed
fix lambda substrait
1 parent 24a2249 commit a1abb74

3 files changed

Lines changed: 76 additions & 53 deletions

File tree

datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ pub async fn from_scalar_function(
3030
f: &ScalarFunction,
3131
input_schema: &DFSchema,
3232
) -> Result<Expr> {
33-
//TODO: handle lambda functions, as they are also encoded as scalar functions
3433
let Some(fn_signature) = consumer
3534
.get_extensions()
3635
.functions

datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use crate::logical_plan::producer::{SubstraitProducer, to_substrait_literal_expr};
19-
use datafusion::common::{DFSchemaRef, ScalarValue, not_impl_err};
19+
use datafusion::common::{DFSchemaRef, ScalarValue, not_impl_err, substrait_err};
2020
use datafusion::logical_expr::{Between, BinaryExpr, Expr, Like, Operator, expr};
2121
use substrait::proto::expression::{RexType, ScalarFunction};
2222
use substrait::proto::function_argument::ArgType;
@@ -39,15 +39,31 @@ pub fn from_lambda_function(
3939

4040
let arguments = std::iter::zip(&fun.args, lambdas_parameters)
4141
.map(|(arg, lambda_parameters)| {
42-
let arg = match lambda_parameters {
43-
Some(lambda_parameters) => {
44-
let mut producer =
45-
producer.with_lambda_parameters(lambda_parameters)?;
42+
let arg = match (arg, lambda_parameters) {
43+
(Expr::Lambda(l), Some(lambda_parameters)) => {
44+
let named_lambda_parameters =
45+
std::iter::zip(&l.params, lambda_parameters)
46+
.map(|(name, parameter)| parameter.with_name(name))
47+
.collect();
4648

47-
producer.handle_expr(arg, schema)?
49+
producer.push_lambda_parameters(named_lambda_parameters)?;
50+
51+
let arg = producer.handle_expr(arg, schema)?;
52+
53+
producer.pop_lambda_parameters()?;
54+
55+
Ok(arg)
4856
}
49-
None => producer.handle_expr(arg, schema)?,
50-
};
57+
(Expr::Lambda(_), None) => substrait_err!(
58+
"{} lambdas_parameters returned None for a lambda",
59+
fun.name()
60+
),
61+
(_, Some(_)) => substrait_err!(
62+
"{} lambdas_parameters returned Some for a value",
63+
fun.name()
64+
),
65+
(_, None) => producer.handle_expr(arg, schema),
66+
}?;
5167

5268
Ok(FunctionArgument {
5369
arg_type: Some(ArgType::Value(arg)),

datafusion/substrait/src/logical_plan/producer/substrait_producer.rs

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ use crate::logical_plan::producer::{
2929
to_substrait_type_from_field,
3030
};
3131
use datafusion::arrow::datatypes::Field;
32-
use datafusion::common::{
33-
Column, DFSchemaRef, HashMap, ScalarValue, substrait_datafusion_err, substrait_err,
34-
};
32+
use datafusion::common::{Column, DFSchemaRef, HashMap, ScalarValue, substrait_err};
3533
use datafusion::execution::SessionState;
3634
use datafusion::execution::registry::SerializerRegistry;
3735
use datafusion::logical_expr::Subquery;
@@ -52,7 +50,7 @@ use substrait::proto::expression::field_reference::{
5250
use substrait::proto::expression::reference_segment::{self, StructField};
5351
use substrait::proto::expression::{FieldReference, ReferenceSegment, RexType};
5452
use substrait::proto::rel::RelType;
55-
use substrait::proto::r#type::Struct;
53+
use substrait::proto::r#type::{Nullability, Struct};
5654
use substrait::proto::{
5755
Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, Rel,
5856
};
@@ -426,12 +424,14 @@ pub trait SubstraitProducer: Send + Sync + Sized {
426424
from_lambda_variable(self, lambda_variable, schema)
427425
}
428426

429-
fn with_lambda_parameters(
427+
fn push_lambda_parameters(
430428
&mut self,
431429
lambda_parameters: Vec<Field>,
432-
) -> datafusion::common::Result<Self>;
430+
) -> datafusion::common::Result<()>;
431+
432+
fn pop_lambda_parameters(&mut self) -> datafusion::common::Result<()>;
433433

434-
fn lambda_variable(&self, name: &str) -> datafusion::common::Result<(i32, u32)>;
434+
fn lambda_variable(&self, name: &str) -> datafusion::common::Result<(u32, i32)>;
435435

436436
fn lambda_parameter_type(
437437
&self,
@@ -444,7 +444,7 @@ fn from_lambda_variable(
444444
lambda_variable: &LambdaVariable,
445445
_schema: &datafusion::common::DFSchema,
446446
) -> Result<Expression, datafusion::error::DataFusionError> {
447-
let (field, steps_out) = producer.lambda_variable(&lambda_variable.name)?;
447+
let (steps_out, field) = producer.lambda_variable(&lambda_variable.name)?;
448448

449449
Ok(Expression {
450450
rex_type: Some(RexType::Selection(Box::new(FieldReference {
@@ -469,7 +469,7 @@ fn from_lambda(
469469
rex_type: Some(RexType::Lambda(Box::new(
470470
substrait::proto::expression::Lambda {
471471
parameters: Some(Struct {
472-
nullability: 1,
472+
nullability: Nullability::Required as i32,
473473
type_variation_reference: 0,
474474
types: lambda
475475
.params
@@ -486,15 +486,15 @@ fn from_lambda(
486486
pub struct DefaultSubstraitProducer<'a> {
487487
extensions: Extensions,
488488
serializer_registry: &'a dyn SerializerRegistry,
489-
lambdas_variables: HashMap<String, (i32, u32, substrait::proto::Type)>,
489+
lambdas_variables: Vec<HashMap<String, (usize, substrait::proto::Type)>>,
490490
}
491491

492492
impl<'a> DefaultSubstraitProducer<'a> {
493493
pub fn new(state: &'a SessionState) -> Self {
494494
DefaultSubstraitProducer {
495495
extensions: Extensions::default(),
496496
serializer_registry: state.serializer_registry().as_ref(),
497-
lambdas_variables: HashMap::new(),
497+
lambdas_variables: Vec::new(),
498498
}
499499
}
500500
}
@@ -550,50 +550,58 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> {
550550
}))
551551
}
552552

553-
fn with_lambda_parameters(
553+
fn push_lambda_parameters(
554554
&mut self,
555555
lambda_parameters: Vec<Field>,
556-
) -> datafusion::common::Result<Self> {
557-
let mut lambdas_variables = self.lambdas_variables.clone();
558-
559-
for (_field_idx, hops_out, _type) in lambdas_variables.values_mut() {
560-
*hops_out += 1;
556+
) -> datafusion::common::Result<()> {
557+
let vars = lambda_parameters
558+
.into_iter()
559+
.enumerate()
560+
.map(|(field_idx, field)| {
561+
Ok((
562+
field.name().clone(),
563+
(
564+
field_idx,
565+
to_substrait_type_from_field(self, &Arc::new(field))?,
566+
),
567+
))
568+
})
569+
.collect::<datafusion::common::Result<_>>()?;
570+
571+
self.lambdas_variables.push(vars);
572+
573+
Ok(())
574+
}
575+
576+
fn pop_lambda_parameters(&mut self) -> datafusion::common::Result<()> {
577+
match self.lambdas_variables.pop() {
578+
Some(_) => Ok(()),
579+
None => substrait_err!("no lambda_parameters to pop"),
561580
}
581+
}
562582

563-
for (field_idx, field) in lambda_parameters.iter().enumerate() {
564-
let hops_out = 0;
565-
566-
lambdas_variables.insert(
567-
field.name().clone(),
568-
(
569-
field_idx as i32,
570-
hops_out,
571-
to_substrait_type_from_field(self, &Arc::new(field.clone()))?,
572-
),
573-
);
583+
fn lambda_variable(&self, name: &str) -> datafusion::common::Result<(u32, i32)> {
584+
for (steps_out, lambda_parameters) in
585+
self.lambdas_variables.iter().rev().enumerate()
586+
{
587+
if let Some((field_idx, _type)) = lambda_parameters.get(name) {
588+
return Ok((steps_out as u32, *field_idx as i32));
589+
}
574590
}
575591

576-
Ok(Self {
577-
extensions: self.extensions.clone(),
578-
serializer_registry: self.serializer_registry,
579-
lambdas_variables,
580-
})
581-
}
582-
583-
fn lambda_variable(&self, name: &str) -> datafusion::common::Result<(i32, u32)> {
584-
self.lambdas_variables
585-
.get(name)
586-
.map(|(field, steps_out, _type)| (*field, *steps_out))
587-
.ok_or_else(|| substrait_datafusion_err!("unknow lambda variable {name}"))
592+
substrait_err!("unknow lambda variable {name}")
588593
}
589594

590595
fn lambda_parameter_type(
591596
&self,
592597
name: &str,
593598
) -> datafusion::common::Result<substrait::proto::Type> {
594-
self.lambdas_variables
595-
.get(name)
596-
.map(|(_field, _steps_out, type_)| type_.clone())
597-
.ok_or_else(|| substrait_datafusion_err!("unknow lambda variable {name}"))
599+
for lambda_parameters in self.lambdas_variables.iter().rev() {
600+
if let Some((_field_idx, type_)) = lambda_parameters.get(name) {
601+
return Ok(type_.clone());
602+
}
603+
}
604+
605+
substrait_err!("unknow lambda variable {name}")
598606
}
599607
}

0 commit comments

Comments
 (0)