Skip to content

Commit fa6ccdc

Browse files
committed
add lambda substrait support
1 parent 6f2c92b commit fa6ccdc

8 files changed

Lines changed: 284 additions & 18 deletions

File tree

datafusion/substrait/src/extensions.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use substrait::proto::extensions::simple_extension_declaration::{
2727
/// types. This struct facilitates the use of these extensions in DataFusion.
2828
/// TODO: DF doesn't yet use extensions for type variations <https://github.com/apache/datafusion/issues/11544>
2929
/// TODO: DF doesn't yet provide valid extensionUris <https://github.com/apache/datafusion/issues/11545>
30-
#[derive(Default, Debug, PartialEq)]
30+
#[derive(Clone, Default, Debug, PartialEq)]
3131
pub struct Extensions {
3232
pub functions: HashMap<u32, String>, // anchor -> function name
3333
pub types: HashMap<u32, String>, // anchor -> type name

datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use datafusion::logical_expr::Expr;
2121
use std::sync::Arc;
2222
use substrait::proto::expression::FieldReference;
2323
use substrait::proto::expression::field_reference::ReferenceType::DirectReference;
24-
use substrait::proto::expression::field_reference::RootType;
24+
use substrait::proto::expression::field_reference::{LambdaParameterReference, RootType};
2525
use substrait::proto::expression::reference_segment::ReferenceType::StructField;
2626

2727
pub async fn from_field_reference(
@@ -56,9 +56,9 @@ pub(crate) fn from_substrait_field_reference(
5656
Some(RootType::Expression(_)) => not_impl_err!(
5757
"Expression root type in field reference is not supported"
5858
),
59-
Some(RootType::LambdaParameterReference(_)) => not_impl_err!(
60-
"Lambda parameter reference in field reference is not yet supported"
61-
),
59+
Some(RootType::LambdaParameterReference(
60+
LambdaParameterReference { steps_out },
61+
)) => consumer.lambda_variable(*steps_out as usize, field_idx),
6262
}
6363
}
6464
_ => not_impl_err!(

datafusion/substrait/src/logical_plan/consumer/expr/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,11 @@ pub async fn from_substrait_rex(
9393
RexType::DynamicParameter(expr) => {
9494
consumer.consume_dynamic_parameter(expr, input_schema).await
9595
}
96-
RexType::Lambda(_) | RexType::LambdaInvocation(_) => {
97-
not_impl_err!("Lambda expressions are not yet supported")
96+
RexType::Lambda(lambda) => {
97+
consumer.consume_lambda(lambda.as_ref(), input_schema).await
98+
}
99+
RexType::LambdaInvocation(_) => {
100+
not_impl_err!("Lambda invocations are not supported")
98101
}
99102
},
100103
None => substrait_err!("Expression must set rex_type: {expression:?}"),

datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ pub async fn from_scalar_function(
4545
let fn_name = substrait_fun_name(fn_signature);
4646
let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;
4747

48+
let udlf_func = consumer.get_function_registry().udlf(fn_name).or_else(|e| {
49+
if let Some(alt_name) = substrait_to_df_name(fn_name) {
50+
consumer.get_function_registry().udlf(alt_name).or(Err(e))
51+
} else {
52+
Err(e)
53+
}
54+
});
55+
4856
let udf_func = consumer.get_function_registry().udf(fn_name).or_else(|e| {
4957
if let Some(alt_name) = substrait_to_df_name(fn_name) {
5058
consumer.get_function_registry().udf(alt_name).or(Err(e))
@@ -53,9 +61,14 @@ pub async fn from_scalar_function(
5361
}
5462
});
5563

56-
// try to first match the requested function into registered udfs, then built-in ops
64+
// try to first match the requested function into registered udlfs, then udfs, then built-in ops
5765
// and finally built-in expressions
58-
if let Ok(func) = udf_func {
66+
if let Ok(func) = udlf_func {
67+
Ok(Expr::LambdaFunction(expr::LambdaFunction::new(
68+
func.to_owned(),
69+
args,
70+
)))
71+
} else if let Ok(func) = udf_func {
5972
Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
6073
func.to_owned(),
6174
args,

datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,24 @@ use super::{
2323
from_substrait_rex, from_window_function,
2424
};
2525
use crate::extensions::Extensions;
26+
use crate::logical_plan::consumer::from_substrait_type_without_names;
2627
use async_trait::async_trait;
27-
use datafusion::arrow::datatypes::DataType;
28+
use datafusion::arrow::datatypes::{DataType, Field, FieldRef};
2829
use datafusion::catalog::TableProvider;
2930
use datafusion::common::{
3031
DFSchema, ScalarValue, TableReference, not_impl_err, substrait_err,
3132
};
3233
use datafusion::execution::{FunctionRegistry, SessionState};
3334
use datafusion::logical_expr::{Expr, Extension, LogicalPlan};
35+
use datafusion::prelude::{lambda, lambda_var};
36+
use std::collections::VecDeque;
3437
use std::sync::{Arc, RwLock};
35-
use substrait::proto;
3638
use substrait::proto::expression as substrait_expression;
3739
use substrait::proto::expression::{
3840
Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction,
3941
SingularOrList, SwitchExpression, WindowFunction,
4042
};
43+
use substrait::proto::{self, Type};
4144
use substrait::proto::{
4245
AggregateRel, ConsistentPartitionWindowRel, CrossRel, DynamicParameter, ExchangeRel,
4346
Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, FetchRel,
@@ -372,6 +375,14 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
372375
not_impl_err!("Dynamic Parameter expression not supported")
373376
}
374377

378+
async fn consume_lambda(
379+
&self,
380+
expr: &proto::expression::Lambda,
381+
input_schema: &DFSchema,
382+
) -> datafusion::common::Result<Expr> {
383+
from_lambda(self, expr, input_schema).await
384+
}
385+
375386
// Outer Schema Stack
376387
// These methods manage a stack of outer schemas for correlated subquery support.
377388
// When entering a subquery, the enclosing query's schema is pushed onto the stack.
@@ -469,6 +480,32 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
469480
};
470481
substrait_err!("Missing handler for user-defined literals {}", type_ref)
471482
}
483+
484+
/// Builds a consumer to be used while converting the body of a lambda whose
/// parameters have the given Substrait types.
///
/// On success, returns the names given to those parameters together with
/// the new consumer.
fn with_lambda_parameters(
485+
&self,
486+
lambda_parameters: &[Type],
487+
) -> datafusion::common::Result<(Vec<String>, Self)>;
488+
489+
/// Resolves a Substrait lambda parameter reference to a DataFusion [`Expr`].
///
/// `steps_out` is the `steps_out` value of the Substrait
/// `LambdaParameterReference`, and `field_idx` is the index of the
/// referenced parameter.
fn lambda_variable(
490+
&self,
491+
steps_out: usize,
492+
field_idx: usize,
493+
) -> datafusion::common::Result<Expr>;
494+
}
495+
496+
async fn from_lambda(
497+
consumer: &impl SubstraitConsumer,
498+
expr: &proto::expression::Lambda,
499+
input_schema: &DFSchema,
500+
) -> datafusion::common::Result<Expr> {
501+
let parameters = expr.parameters.as_ref().unwrap();
502+
503+
let (names, consumer) = consumer.with_lambda_parameters(&parameters.types)?;
504+
505+
let body = expr.body.as_ref().unwrap();
506+
let body = consumer.consume_expression(body, input_schema).await?;
507+
508+
Ok(lambda(names, body))
472509
}
473510

474511
/// Default SubstraitConsumer for converting standard Substrait without user-defined extensions.
@@ -478,6 +515,8 @@ pub struct DefaultSubstraitConsumer<'a> {
478515
pub(super) extensions: &'a Extensions,
479516
pub(super) state: &'a SessionState,
480517
outer_schemas: RwLock<Vec<Arc<DFSchema>>>,
518+
lambdas_parameters: VecDeque<Vec<FieldRef>>,
519+
num_lambda_parameters: usize,
481520
}
482521

483522
impl<'a> DefaultSubstraitConsumer<'a> {
@@ -486,12 +525,61 @@ impl<'a> DefaultSubstraitConsumer<'a> {
486525
extensions,
487526
state,
488527
outer_schemas: RwLock::new(Vec::new()),
528+
lambdas_parameters: VecDeque::new(),
529+
num_lambda_parameters: 0,
489530
}
490531
}
491532
}
492533

493534
#[async_trait]
494535
impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
536+
fn with_lambda_parameters(
537+
&self,
538+
lambda_parameters: &[Type],
539+
) -> datafusion::common::Result<(Vec<String>, Self)> {
540+
let mut lambdas_parameters = self.lambdas_parameters.clone();
541+
542+
let lambda_parameters = lambda_parameters
543+
.iter()
544+
.enumerate()
545+
.map(|(i, ty)| {
546+
let dt = from_substrait_type_without_names(self, ty)?;
547+
548+
Ok(Arc::new(Field::new(
549+
format!("p{}", i + self.num_lambda_parameters),
550+
dt,
551+
true,
552+
)))
553+
})
554+
.collect::<datafusion::common::Result<Vec<_>>>()?;
555+
556+
let names = lambda_parameters.iter().map(|f| f.name().clone()).collect();
557+
let num_lambda_parameters = self.num_lambda_parameters + lambda_parameters.len();
558+
559+
lambdas_parameters.push_front(lambda_parameters);
560+
561+
Ok((
562+
names,
563+
Self {
564+
extensions: self.extensions,
565+
state: self.state,
566+
outer_schemas: RwLock::new(self.outer_schemas.read().unwrap().clone()),
567+
lambdas_parameters,
568+
num_lambda_parameters,
569+
},
570+
))
571+
}
572+
573+
fn lambda_variable(
574+
&self,
575+
steps_out: usize,
576+
field_idx: usize,
577+
) -> datafusion::common::Result<Expr> {
578+
let var = &self.lambdas_parameters[steps_out][field_idx];
579+
580+
Ok(lambda_var(var.name(), Arc::clone(var)))
581+
}
582+
495583
async fn resolve_table_ref(
496584
&self,
497585
table_ref: &TableReference,

datafusion/substrait/src/logical_plan/producer/expr/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,8 @@ pub fn to_substrait_rex(
150150
}
151151
Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
152152
Expr::LambdaFunction(expr) => producer.handle_lambda_function(expr, schema),
153-
Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
154-
Expr::LambdaVariable(expr) => {
155-
not_impl_err!("Cannot convert {expr:?} to Substrait")
156-
}
153+
Expr::Lambda(expr) => producer.handle_lambda(expr, schema),
154+
Expr::LambdaVariable(expr) => producer.handle_lambda_variable(expr, schema),
157155
}
158156
}
159157

datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,37 @@ pub fn from_lambda_function(
3535
fun: &expr::LambdaFunction,
3636
schema: &DFSchemaRef,
3737
) -> datafusion::common::Result<Expression> {
38-
from_function(producer, fun.name(), &fun.args, schema)
38+
let lambdas_parameters = fun.lambdas_parameters(schema)?;
39+
40+
let arguments = std::iter::zip(&fun.args, lambdas_parameters)
41+
.map(|(arg, lambda_parameters)| {
42+
let arg = match lambda_parameters {
43+
Some(lambda_parameters) => {
44+
let mut producer =
45+
producer.with_lambda_parameters(lambda_parameters)?;
46+
47+
producer.handle_expr(arg, schema)?
48+
}
49+
None => producer.handle_expr(arg, schema)?,
50+
};
51+
52+
Ok(FunctionArgument {
53+
arg_type: Some(ArgType::Value(arg)),
54+
})
55+
})
56+
.collect::<datafusion::common::Result<_>>()?;
57+
58+
let function_anchor = producer.register_function(fun.name().to_string());
59+
#[expect(deprecated)]
60+
Ok(Expression {
61+
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
62+
function_reference: function_anchor,
63+
arguments,
64+
output_type: None,
65+
options: vec![],
66+
args: vec![],
67+
})),
68+
})
3969
}
4070

4171
fn from_function(

0 commit comments

Comments
 (0)