Skip to content

Commit 3eec7a9

Browse files
committed
finish lambda substrait support
1 parent 27479e2 commit 3eec7a9

9 files changed

Lines changed: 602 additions & 128 deletions

File tree

datafusion/core/src/execution/context/mod.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,6 +1566,20 @@ impl SessionContext {
15661566
state.register_udf(Arc::new(f)).ok();
15671567
}
15681568

1569+
/// Registers a higher-order function within this context.
1570+
///
1571+
/// Note in SQL queries, function names are looked up using
1572+
/// lowercase unless the query uses quotes. For example,
1573+
///
1574+
/// - `SELECT MY_HIGHER_ORDER_FUNC(x)...` will look for a function named `"my_higher_order_func"`
1575+
/// - `SELECT "my_HIGHER_ORDER_FUNC"(x)` will look for a function named `"my_HIGHER_ORDER_FUNC"`
1576+
///
1577+
/// Any functions registered with the function name or its aliases will be overwritten with this new function
1578+
pub fn register_higher_order_function(&self, f: Arc<dyn HigherOrderUDF>) {
1579+
let mut state = self.state.write();
1580+
state.register_higher_order_function(f).ok();
1581+
}
1582+
15691583
/// Registers an aggregate UDF within this context.
15701584
///
15711585
/// Note in SQL queries, aggregate names are looked up using
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::{
19+
common::{DFSchema, substrait_err},
20+
prelude::{Expr, lambda},
21+
};
22+
use substrait::proto;
23+
24+
use crate::logical_plan::consumer::SubstraitConsumer;
25+
26+
pub async fn from_lambda(
27+
consumer: &impl SubstraitConsumer,
28+
expr: &proto::expression::Lambda,
29+
input_schema: &DFSchema,
30+
) -> datafusion::common::Result<Expr> {
31+
let Some(parameters) = expr.parameters.as_ref() else {
32+
return substrait_err!("Lambda expression without parameters is not allowed");
33+
};
34+
35+
let (names, consumer_with_parameters) =
36+
consumer.with_lambda_parameters(&parameters.types, input_schema)?;
37+
38+
let Some(body) = expr.body.as_ref() else {
39+
return substrait_err!("Lambda expression without body is not allowed");
40+
};
41+
42+
let body = consumer_with_parameters
43+
.consume_expression(body, input_schema)
44+
.await?;
45+
46+
Ok(lambda(names, body))
47+
}

datafusion/substrait/src/logical_plan/consumer/expr/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ mod cast;
2020
mod field_reference;
2121
mod function_arguments;
2222
mod if_then;
23+
mod lambda;
2324
mod literal;
2425
mod nested;
2526
mod scalar_function;
@@ -32,6 +33,7 @@ pub use cast::*;
3233
pub use field_reference::*;
3334
pub use function_arguments::*;
3435
pub use if_then::*;
36+
pub use lambda::*;
3537
pub use literal::*;
3638
pub use nested::*;
3739
pub use scalar_function::*;

datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,19 @@ pub async fn from_scalar_function(
4444
let fn_name = substrait_fun_name(fn_signature);
4545
let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;
4646

47-
let udlf_func = consumer.get_function_registry().udlf(fn_name).or_else(|e| {
48-
if let Some(alt_name) = substrait_to_df_name(fn_name) {
49-
consumer.get_function_registry().udlf(alt_name).or(Err(e))
50-
} else {
51-
Err(e)
52-
}
53-
});
47+
let udlf_func = consumer
48+
.get_function_registry()
49+
.higher_order_function(fn_name)
50+
.or_else(|e| {
51+
if let Some(alt_name) = substrait_to_df_name(fn_name) {
52+
consumer
53+
.get_function_registry()
54+
.higher_order_function(alt_name)
55+
.or(Err(e))
56+
} else {
57+
Err(e)
58+
}
59+
});
5460

5561
let udf_func = consumer.get_function_registry().udf(fn_name).or_else(|e| {
5662
if let Some(alt_name) = substrait_to_df_name(fn_name) {
@@ -63,7 +69,7 @@ pub async fn from_scalar_function(
6369
// first try to match the requested function against registered higher-order functions (udlfs), then udfs, built-in ops
6470
// and finally built-in expressions
6571
if let Ok(func) = udlf_func {
66-
Ok(Expr::LambdaFunction(expr::LambdaFunction::new(
72+
Ok(Expr::HigherOrderFunction(expr::HigherOrderFunction::new(
6773
func.to_owned(),
6874
args,
6975
)))

0 commit comments

Comments
 (0)