Skip to content

Commit a59ffe8

Browse files
committed
simplify LambdaUDF coerce_value_types
1 parent d874db7 commit a59ffe8

4 files changed

Lines changed: 50 additions & 39 deletions

File tree

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -167,36 +167,42 @@ pub fn value_fields_with_lambda_udf<L: Clone>(
167167
LambdaTypeSignature::UserDefined => {
168168
let arg_types = current_fields
169169
.iter()
170-
.map(|p| match p {
171-
ValueOrLambda::Value(field) => {
172-
ValueOrLambda::Value(field.data_type().clone())
173-
}
174-
ValueOrLambda::Lambda(_) => ValueOrLambda::Lambda(()),
170+
.filter_map(|p| match p {
171+
ValueOrLambda::Value(field) => Some(field.data_type().clone()),
172+
ValueOrLambda::Lambda(_) => None,
175173
})
176174
.collect::<Vec<_>>();
177175

178176
let coerced_types = func.coerce_value_types(&arg_types)?;
179177

180-
std::iter::zip(current_fields, coerced_types)
181-
.map(|(field, coerce_to)| match (field, coerce_to) {
182-
(ValueOrLambda::Value(field), Some(coerce_to)) => {
183-
Ok(ValueOrLambda::Value(Arc::new(
184-
field.as_ref().clone().with_data_type(coerce_to),
185-
)))
178+
if coerced_types.len() != arg_types.len() {
179+
return plan_err!(
180+
"{} coerce_value_types should have returned {} items but returned {}",
181+
func.name(),
182+
arg_types.len(),
183+
coerced_types.len()
184+
);
185+
}
186+
187+
let mut coerced_types = coerced_types.into_iter();
188+
189+
Ok(current_fields
190+
.iter()
191+
.map(|current_field| match current_field {
192+
ValueOrLambda::Value(field) => {
193+
let data_type = coerced_types
194+
.next()
195+
.expect("coerced_types len should have been checked above");
196+
197+
ValueOrLambda::Value(Arc::new(
198+
field.as_ref().clone().with_data_type(data_type),
199+
))
186200
}
187-
(ValueOrLambda::Lambda(v), None) => {
188-
Ok(ValueOrLambda::Lambda(v.clone()))
201+
ValueOrLambda::Lambda(lambda) => {
202+
ValueOrLambda::Lambda(lambda.clone())
189203
}
190-
(ValueOrLambda::Value(_), None) => plan_err!(
191-
"{} coerce_values_types returned None for a value",
192-
func.name()
193-
),
194-
(ValueOrLambda::Lambda(_), Some(_)) => plan_err!(
195-
"{} coerce_values_types returned Some for a lambda",
196-
func.name()
197-
),
198204
})
199-
.collect()
205+
.collect())
200206
}
201207
LambdaTypeSignature::VariadicAny => Ok(current_fields.to_vec()),
202208
LambdaTypeSignature::Any(number) => {

datafusion/expr/src/udlf.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -606,21 +606,21 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync {
606606
/// See the [type coercion module](crate::type_coercion)
607607
/// documentation for more details on type coercion
608608
///
609-
/// For example, if your function requires a floating point arguments, but the user calls
610-
/// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]`
611-
/// to ensure the argument is converted to `1::double`
609+
/// For example, if your function requires a contiguous list argument, but the user calls
610+
/// it like `my_func(c, v -> v+2)` (i.e. with `c` as a ListView), coerce_types can return `[DataType::List(..)]`
611+
/// to ensure the argument is converted to a List
612612
///
613613
/// # Parameters
614-
/// * `arg_types`: The argument types of the arguments this function with
614+
/// * `arg_types`: The argument types of the value arguments of this function, excluding lambdas
615615
///
616616
/// # Return value
617617
/// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
618618
/// arguments to these specific types.
619-
fn coerce_value_types(
620-
&self,
621-
_arg_types: &[ValueOrLambda<DataType, ()>],
622-
) -> Result<Vec<Option<DataType>>> {
623-
not_impl_err!("Function {} does not implement coerce_types", self.name())
619+
fn coerce_value_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
620+
not_impl_err!(
621+
"Function {} does not implement coerce_value_types",
622+
self.name()
623+
)
624624
}
625625

626626
/// Returns the documentation for this Lambda UDF.

datafusion/functions-nested/src/array_transform.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,16 @@ impl LambdaUDF for ArrayTransform {
100100
&self.signature
101101
}
102102

103-
fn coerce_value_types(
104-
&self,
105-
arg_types: &[ValueOrLambda<DataType, ()>],
106-
) -> Result<Vec<Option<DataType>>> {
107-
let (list, _lambda) = value_lambda_pair(self.name(), arg_types)?;
103+
fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
104+
let list = if arg_types.len() == 1 {
105+
&arg_types[0]
106+
} else {
107+
return plan_err!(
108+
"{} function requires 1 value arguments, got {}",
109+
self.name(),
110+
arg_types.len()
111+
);
112+
};
108113

109114
let coerced = match list {
110115
DataType::List(_)
@@ -121,7 +126,7 @@ impl LambdaUDF for ArrayTransform {
121126
}
122127
};
123128

124-
Ok(vec![Some(coerced), None])
129+
Ok(vec![coerced])
125130
}
126131

127132
fn lambdas_parameters(

datafusion/sqllogictest/test_files/lambda.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,13 +245,13 @@ select array_transform(arrow_cast(t.list, 'ListView(Int32)'), a -> a+1) from t;
245245
query error
246246
select array_transform();
247247
----
248-
DataFusion error: Execution error: array_transform function requires 2 arguments, got 0
248+
DataFusion error: Error during planning: array_transform function requires 1 value arguments, got 0
249249

250250

251251
query error DataFusion error: Error during planning: array_transform expected a list as first argument, got Int64
252252
select array_transform(1, v -> v*2);
253253

254-
query error DataFusion error: Error during planning: array_transform expects a value followed by a lambda, got Lambda\(\(\)\) and Value\(List\(Field \{ data_type: Int64, nullable: true \}\)\)
254+
query error DataFusion error: Error during planning: array_transform expects a value followed by a lambda, got Lambda\(\(\)\) and Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\)
255255
select array_transform(v -> v*2, [1, 2]);
256256

257257
query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 2

0 commit comments

Comments
 (0)