Skip to content

Commit 90eb08f

Browse files
committed
improve lambdas
1 parent 2be9e54 commit 90eb08f

9 files changed

Lines changed: 343 additions & 257 deletions

File tree

datafusion/expr/src/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3411,7 +3411,7 @@ impl Display for SchemaDisplay<'_> {
34113411
)
34123412
}
34133413
Expr::LambdaVariable(c) => {
3414-
write!(f, "{}", c.name)
3414+
f.write_str(&c.name)
34153415
}
34163416
}
34173417
}

datafusion/expr/src/udlf.rs

Lines changed: 0 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@
1919
2020
use crate::expr::schema_name_from_exprs_comma_separated_without_space;
2121
use crate::simplify::{ExprSimplifyResult, SimplifyContext};
22-
use crate::sort_properties::{ExprProperties, SortProperties};
2322
use crate::{ColumnarValue, Documentation, Expr};
2423
use arrow::array::{ArrayRef, RecordBatch};
2524
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
2625
use datafusion_common::config::ConfigOptions;
2726
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
2827
use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
29-
use datafusion_expr_common::interval_arithmetic::Interval;
3028
use datafusion_expr_common::signature::Volatility;
3129
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
3230
use std::any::Any;
@@ -602,83 +600,6 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync {
602600
}
603601
}
604602

605-
/// Computes the output [`Interval`] for a [`LambdaUDF`], given the input
606-
/// intervals.
607-
///
608-
/// # Parameters
609-
///
610-
/// * `children` are the intervals for the children (inputs) of this function.
611-
///
612-
/// # Example
613-
///
614-
/// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
615-
/// then the output interval would be `[0, 3]`.
616-
fn evaluate_bounds(&self, _input: &[&Interval]) -> Result<Interval> {
617-
// We cannot assume the input datatype is the same as the output type.
618-
Interval::make_unbounded(&DataType::Null)
619-
}
620-
621-
/// Updates bounds for child expressions, given a known [`Interval`] for this
622-
/// function.
623-
///
624-
/// This function is used to propagate constraints down through an
625-
/// expression tree.
626-
///
627-
/// # Parameters
628-
///
629-
/// * `interval` is the currently known interval for this function.
630-
/// * `inputs` are the current intervals for the inputs (children) of this function.
631-
///
632-
/// # Returns
633-
///
634-
/// A `Vec` of new intervals for the children, in order.
635-
///
636-
/// If constraint propagation reveals an infeasibility for any child, returns
637-
/// [`None`]. If none of the children intervals change as a result of
638-
/// propagation, may return an empty vector instead of cloning `children`.
639-
/// This is the default (and conservative) return value.
640-
///
641-
/// # Example
642-
///
643-
/// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
644-
/// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
645-
fn propagate_constraints(
646-
&self,
647-
_interval: &Interval,
648-
_inputs: &[&Interval],
649-
) -> Result<Option<Vec<Interval>>> {
650-
Ok(Some(vec![]))
651-
}
652-
653-
/// Calculates the [`SortProperties`] of this function based on its children's properties.
654-
fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
655-
if !self.preserves_lex_ordering(inputs)? {
656-
return Ok(SortProperties::Unordered);
657-
}
658-
659-
let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else {
660-
return Ok(SortProperties::Singleton);
661-
};
662-
663-
if inputs
664-
.iter()
665-
.skip(1)
666-
.all(|input| &input.sort_properties == first_order)
667-
{
668-
Ok(*first_order)
669-
} else {
670-
Ok(SortProperties::Unordered)
671-
}
672-
}
673-
674-
/// Returns true if the function preserves lexicographical ordering based on
675-
/// the input ordering.
676-
///
677-
/// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not.
678-
fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
679-
Ok(false)
680-
}
681-
682603
/// Coerce arguments of a function call to types that the function can evaluate.
683604
///
684605
/// See the [type coercion module](crate::type_coercion)

datafusion/optimizer/src/common_subexpr_eliminate.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use datafusion_common::alias::AliasGenerator;
3030
use datafusion_common::cse::{CSE, CSEController, FoundCommonNodes};
3131
use datafusion_common::tree_node::{Transformed, TreeNode};
3232
use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, qualified_name};
33-
use datafusion_expr::expr::{Alias, ScalarFunction};
33+
use datafusion_expr::expr::{Alias, LambdaFunction, ScalarFunction};
3434
use datafusion_expr::logical_plan::{
3535
Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
3636
};
@@ -651,12 +651,15 @@ impl CSEController for ExprCSEController<'_> {
651651

652652
fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
653653
match node {
654-
// In case of `ScalarFunction`s we don't know which children are surely
654+
// In case of `ScalarFunction`s and `LambdaFunction`s we don't know which children are surely
655655
// executed so start visiting all children conditionally and stop the
656656
// recursion with `TreeNodeRecursion::Jump`.
657657
Expr::ScalarFunction(ScalarFunction { func, args }) => {
658658
func.conditional_arguments(args)
659659
}
660+
Expr::LambdaFunction(LambdaFunction { func, args }) => {
661+
func.conditional_arguments(args)
662+
}
660663

661664
// In case of `And` and `Or` the first child is surely executed, but we
662665
// account subexpressions as conditional in the second.
@@ -696,7 +699,8 @@ impl CSEController for ExprCSEController<'_> {
696699
}
697700

698701
fn is_valid(node: &Expr) -> bool {
699-
!node.is_volatile_node() && !matches!(node, Expr::LambdaVariable(_))
702+
!node.is_volatile_node()
703+
&& !matches!(node, Expr::Lambda(_) | Expr::LambdaVariable(_))
700704
}
701705

702706
fn is_ignored(&self, node: &Expr) -> bool {
@@ -726,6 +730,7 @@ impl CSEController for ExprCSEController<'_> {
726730
| Expr::ScalarVariable(..)
727731
| Expr::Alias(..)
728732
| Expr::Wildcard { .. }
733+
| Expr::Lambda(_)
729734
| Expr::LambdaVariable(_)
730735
);
731736

datafusion/physical-expr/src/expressions/lambda.rs

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ use arrow::{
2727
datatypes::{DataType, Schema},
2828
record_batch::RecordBatch,
2929
};
30-
use datafusion_common::{internal_err, tree_node::TreeNodeVisitor, HashSet, Result};
30+
use datafusion_common::{HashSet, Result, internal_err, tree_node::TreeNodeVisitor};
3131
use datafusion_common::{
3232
plan_err,
3333
tree_node::{TreeNode, TreeNodeRecursion},
3434
};
3535
use datafusion_expr::ColumnarValue;
36-
use hashbrown::{hash_map::EntryRef, HashMap};
36+
use hashbrown::{HashMap, hash_map::EntryRef};
3737

3838
/// Represents a lambda with the given parameter names and body
3939
#[derive(Debug, Eq, Clone)]
@@ -100,6 +100,14 @@ impl LambdaExpr {
100100
&self.captured_columns
101101
}
102102

103+
/// Returns the names of lambda variables used in the body that are bound neither by this lambda nor by any lambda nested inside it.
104+
/// Example:
105+
///
106+
/// `array_transform([[[1, 2, 3]]], a -> array_transform(a, b -> array_transform(b, c -> length(a) + length(b) + c)))`
107+
///
108+
/// For the outermost lambda, this would return an empty hash set
109+
/// For the middle one, `HashSet("a")`
110+
/// And for the innermost, `HashSet("a", "b")`
103111
pub(crate) fn captured_variables(&self) -> &HashSet<String> {
104112
&self.captured_variables
105113
}
@@ -192,7 +200,7 @@ impl<'n> TreeNodeVisitor<'n> for Captures<'n> {
192200
fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
193201
if let Some(lambda) = node.as_any().downcast_ref::<LambdaExpr>() {
194202
for param in &lambda.params {
195-
*self.shadows.entry_ref(param.as_str()).or_default() += 1;
203+
*self.shadows.entry(param.as_str()).or_default() += 1;
196204
}
197205
} else if let Some(lambda_variable) =
198206
node.as_any().downcast_ref::<LambdaVariable>()
@@ -230,3 +238,101 @@ impl<'n> TreeNodeVisitor<'n> for Captures<'n> {
230238
Ok(TreeNodeRecursion::Continue)
231239
}
232240
}
241+
242+
#[cfg(test)]
243+
mod tests {
244+
use crate::{
245+
LambdaFunctionExpr,
246+
expressions::{Column, LambdaExpr, NoOp, lambda::lambda, lambda_variable},
247+
};
248+
use arrow::{
249+
array::RecordBatch,
250+
datatypes::{DataType, Field, FieldRef, Schema},
251+
};
252+
use datafusion_common::{HashSet, Result};
253+
use datafusion_expr::{ColumnarValue, LambdaUDF};
254+
use std::sync::Arc;
255+
256+
#[derive(Debug, Hash, Eq, PartialEq)]
257+
struct DummyLambdaUDF;
258+
259+
impl LambdaUDF for DummyLambdaUDF {
260+
fn as_any(&self) -> &dyn std::any::Any {
261+
unimplemented!()
262+
}
263+
264+
fn name(&self) -> &str {
265+
"dummy_udlf"
266+
}
267+
268+
fn signature(&self) -> &datafusion_expr::LambdaSignature {
269+
unimplemented!()
270+
}
271+
272+
fn lambdas_parameters(
273+
&self,
274+
_args: &[datafusion_expr::ValueOrLambda<FieldRef, ()>],
275+
) -> Result<Vec<Option<Vec<Field>>>> {
276+
unimplemented!()
277+
}
278+
279+
fn return_field_from_args(
280+
&self,
281+
_args: datafusion_expr::LambdaReturnFieldArgs,
282+
) -> Result<FieldRef> {
283+
unimplemented!()
284+
}
285+
286+
fn invoke_with_args(
287+
&self,
288+
_args: datafusion_expr::LambdaFunctionArgs,
289+
) -> Result<ColumnarValue> {
290+
unimplemented!()
291+
}
292+
}
293+
294+
#[test]
295+
fn test_lambda_captures() {
296+
let null_field = Arc::new(Field::new("", DataType::Null, true));
297+
298+
// `var_b -> dummy_udlf(var_a, var_b, column@0, var_c -> var_c)`
299+
let inner = LambdaExpr::try_new(
300+
vec![String::from("var_b")],
301+
Arc::new(LambdaFunctionExpr::new(
302+
"dummy_udlf",
303+
Arc::new(DummyLambdaUDF),
304+
vec![
305+
lambda_variable("var_a", Arc::clone(&null_field)).unwrap(),
306+
lambda_variable("var_b", Arc::clone(&null_field)).unwrap(),
307+
Arc::new(Column::new("column", 0)),
308+
lambda(
309+
["var_c"],
310+
lambda_variable("var_c", Arc::clone(&null_field)).unwrap(),
311+
)
312+
.unwrap(),
313+
],
314+
Arc::clone(&null_field),
315+
Arc::new(Default::default()),
316+
)),
317+
)
318+
.unwrap();
319+
320+
assert_eq!(inner.captured_columns(), &HashSet::from([0]));
321+
assert_eq!(
322+
inner.captured_variables(),
323+
&HashSet::from([String::from("var_a")])
324+
);
325+
}
326+
327+
#[test]
328+
fn test_lambda_evaluate() {
329+
let lambda = lambda(["a"], Arc::new(NoOp::new())).unwrap();
330+
let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
331+
assert!(lambda.evaluate(&batch).is_err());
332+
}
333+
334+
#[test]
335+
fn test_lambda_duplicate_name() {
336+
assert!(lambda(["a", "a"], Arc::new(NoOp::new())).is_err());
337+
}
338+
}

datafusion/physical-expr/src/expressions/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ mod case;
2323
mod cast;
2424
mod cast_column;
2525
mod column;
26-
mod lambda_variable;
2726
mod dynamic_filters;
2827
mod in_list;
2928
mod is_not_null;
3029
mod is_null;
30+
mod lambda_variable;
3131
mod lambda;
3232
mod like;
3333
mod literal;

datafusion/physical-expr/src/lambda_function.rs

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ use arrow::array::{Array, NullArray, RecordBatch};
4141
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
4242
use datafusion_common::config::{ConfigEntry, ConfigOptions};
4343
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
44-
use datafusion_expr::interval_arithmetic::Interval;
45-
use datafusion_expr::sort_properties::ExprProperties;
4644
use datafusion_expr::type_coercion::functions::value_fields_with_lambda_udf;
4745
use datafusion_expr::{
4846
ColumnarValue, LambdaArgument, LambdaFunctionArgs, LambdaReturnFieldArgs, LambdaUDF,
@@ -243,7 +241,6 @@ fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
243241
}
244242

245243
impl PhysicalExpr for LambdaFunctionExpr {
246-
/// Return a reference to Any that can be used for downcasting
247244
fn as_any(&self) -> &dyn Any {
248245
self
249246
}
@@ -404,34 +401,6 @@ impl PhysicalExpr for LambdaFunctionExpr {
404401
)))
405402
}
406403

407-
fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
408-
self.fun.evaluate_bounds(children)
409-
}
410-
411-
fn propagate_constraints(
412-
&self,
413-
interval: &Interval,
414-
children: &[&Interval],
415-
) -> Result<Option<Vec<Interval>>> {
416-
self.fun.propagate_constraints(interval, children)
417-
}
418-
419-
fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
420-
let sort_properties = self.fun.output_ordering(children)?;
421-
let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
422-
let children_range = children
423-
.iter()
424-
.map(|props| &props.range)
425-
.collect::<Vec<_>>();
426-
let range = self.fun().evaluate_bounds(&children_range)?;
427-
428-
Ok(ExprProperties {
429-
sort_properties,
430-
range,
431-
preserves_lex_ordering,
432-
})
433-
}
434-
435404
fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
436405
write!(f, "{}(", self.name)?;
437406
for (i, expr) in self.args.iter().enumerate() {

0 commit comments

Comments
 (0)