Skip to content

Commit df1b740

Browse files
committed
feat: add substrait lambda support
1 parent 42cd2fa commit df1b740

12 files changed

Lines changed: 866 additions & 29 deletions

File tree

datafusion/substrait/src/extensions.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use substrait::proto::extensions::simple_extension_declaration::{
2727
/// types. This structs facilitates the use of these extensions in DataFusion.
2828
/// TODO: DF doesn't yet use extensions for type variations <https://github.com/apache/datafusion/issues/11544>
2929
/// TODO: DF doesn't yet provide valid extensionUris <https://github.com/apache/datafusion/issues/11545>
30-
#[derive(Default, Debug, PartialEq)]
30+
#[derive(Clone, Default, Debug, PartialEq)]
3131
pub struct Extensions {
3232
pub functions: HashMap<u32, String>, // anchor -> function name
3333
pub types: HashMap<u32, String>, // anchor -> type name

datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use datafusion::logical_expr::Expr;
2121
use std::sync::Arc;
2222
use substrait::proto::expression::FieldReference;
2323
use substrait::proto::expression::field_reference::ReferenceType::DirectReference;
24-
use substrait::proto::expression::field_reference::RootType;
24+
use substrait::proto::expression::field_reference::{LambdaParameterReference, RootType};
2525
use substrait::proto::expression::reference_segment::ReferenceType::StructField;
2626

2727
pub async fn from_field_reference(
@@ -56,9 +56,9 @@ pub(crate) fn from_substrait_field_reference(
5656
Some(RootType::Expression(_)) => not_impl_err!(
5757
"Expression root type in field reference is not supported"
5858
),
59-
Some(RootType::LambdaParameterReference(_)) => not_impl_err!(
60-
"Lambda parameter reference in field reference is not yet supported"
61-
),
59+
Some(RootType::LambdaParameterReference(
60+
LambdaParameterReference { steps_out },
61+
)) => consumer.lambda_variable(*steps_out as usize, field_idx),
6262
}
6363
}
6464
_ => not_impl_err!(
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::{
19+
common::{DFSchema, substrait_err},
20+
prelude::{Expr, lambda},
21+
};
22+
use substrait::proto;
23+
24+
use crate::logical_plan::consumer::SubstraitConsumer;
25+
26+
pub async fn from_lambda(
27+
consumer: &impl SubstraitConsumer,
28+
expr: &proto::expression::Lambda,
29+
input_schema: &DFSchema,
30+
) -> datafusion::common::Result<Expr> {
31+
let Some(parameters) = expr.parameters.as_ref() else {
32+
return substrait_err!("Lambda expression without parameters is not allowed");
33+
};
34+
35+
let (names, consumer_with_parameters) =
36+
consumer.with_lambda_parameters(&parameters.types, input_schema)?;
37+
38+
let Some(body) = expr.body.as_ref() else {
39+
return substrait_err!("Lambda expression without body is not allowed");
40+
};
41+
42+
let body = consumer_with_parameters
43+
.consume_expression(body, input_schema)
44+
.await?;
45+
46+
Ok(lambda(names, body))
47+
}

datafusion/substrait/src/logical_plan/consumer/expr/mod.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ mod cast;
2020
mod field_reference;
2121
mod function_arguments;
2222
mod if_then;
23+
mod lambda;
2324
mod literal;
2425
mod nested;
2526
mod scalar_function;
@@ -32,6 +33,7 @@ pub use cast::*;
3233
pub use field_reference::*;
3334
pub use function_arguments::*;
3435
pub use if_then::*;
36+
pub use lambda::*;
3537
pub use literal::*;
3638
pub use nested::*;
3739
pub use scalar_function::*;
@@ -95,8 +97,11 @@ pub async fn from_substrait_rex(
9597
RexType::DynamicParameter(expr) => {
9698
consumer.consume_dynamic_parameter(expr, input_schema).await
9799
}
98-
RexType::Lambda(_) | RexType::LambdaInvocation(_) => {
99-
not_impl_err!("Lambda expressions are not yet supported")
100+
RexType::Lambda(lambda) => {
101+
consumer.consume_lambda(lambda.as_ref(), input_schema).await
102+
}
103+
RexType::LambdaInvocation(_) => {
104+
not_impl_err!("Lambda invocations are not supported")
100105
}
101106
},
102107
None => substrait_err!("Expression must set rex_type: {expression:?}"),

datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,20 @@ pub async fn from_scalar_function(
4545
let fn_name = substrait_fun_name(fn_signature);
4646
let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;
4747

48+
let udlf_func = consumer
49+
.get_function_registry()
50+
.higher_order_function(fn_name)
51+
.or_else(|e| {
52+
if let Some(alt_name) = substrait_to_df_name(fn_name) {
53+
consumer
54+
.get_function_registry()
55+
.higher_order_function(alt_name)
56+
.or(Err(e))
57+
} else {
58+
Err(e)
59+
}
60+
});
61+
4862
let udf_func = consumer.get_function_registry().udf(fn_name).or_else(|e| {
4963
if let Some(alt_name) = substrait_to_df_name(fn_name) {
5064
consumer.get_function_registry().udf(alt_name).or(Err(e))
@@ -53,9 +67,14 @@ pub async fn from_scalar_function(
5367
}
5468
});
5569

56-
// try to first match the requested function into registered udfs, then built-in ops
70+
// try to first match the requested function into registered udlfs, then udfs, built-in ops
5771
// and finally built-in expressions
58-
if let Ok(func) = udf_func {
72+
if let Ok(func) = udlf_func {
73+
Ok(Expr::HigherOrderFunction(expr::HigherOrderFunction::new(
74+
func.to_owned(),
75+
args,
76+
)))
77+
} else if let Ok(func) = udf_func {
5978
Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
6079
func.to_owned(),
6180
args,

0 commit comments

Comments
 (0)