Skip to content

Commit 2e5532c

Browse files
committed
Snowflake: Add support for Lambda functions
1 parent 0924f3a commit 2e5532c

6 files changed

Lines changed: 93 additions & 8 deletions

File tree

src/ast/mod.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1421,7 +1421,7 @@ impl fmt::Display for AccessExpr {
14211421
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
14221422
pub struct LambdaFunction {
14231423
/// The parameters to the lambda function.
1424-
pub params: OneOrManyWithParens<Ident>,
1424+
pub params: OneOrManyWithParens<LambdaFunctionParameter>,
14251425
/// The body of the lambda function.
14261426
pub body: Box<Expr>,
14271427
/// The syntax style used to write the lambda function.
@@ -1446,6 +1446,26 @@ impl fmt::Display for LambdaFunction {
14461446
}
14471447
}
14481448

1449+
/// A parameter to a lambda function, optionally with a data type.
1450+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
1451+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
1452+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
1453+
pub struct LambdaFunctionParameter {
1454+
/// The name of the parameter
1455+
pub name: Ident,
1456+
/// The optional data type of the parameter
1457+
pub data_type: Option<DataType>,
1458+
}
1459+
1460+
impl fmt::Display for LambdaFunctionParameter {
1461+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1462+
match &self.data_type {
1463+
Some(dt) => write!(f, "{} {}", self.name, dt),
1464+
None => write!(f, "{}", self.name),
1465+
}
1466+
}
1467+
}
1468+
14491469
/// The syntax style for a lambda function.
14501470
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Copy)]
14511471
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]

src/dialect/snowflake.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,11 @@ impl Dialect for SnowflakeDialect {
632632
fn supports_select_wildcard_rename(&self) -> bool {
633633
true
634634
}
635+
636+
/// See <https://docs.snowflake.com/en/user-guide/querying-semistructured#label-higher-order-functions>
637+
fn supports_lambda_functions(&self) -> bool {
638+
true
639+
}
635640
}
636641

637642
// Peeks ahead to identify tokens that are expected after

src/parser/mod.rs

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,7 +1606,25 @@ impl<'a> Parser<'a> {
16061606
Token::Arrow if self.dialect.supports_lambda_functions() => {
16071607
self.expect_token(&Token::Arrow)?;
16081608
Ok(Expr::Lambda(LambdaFunction {
1609-
params: OneOrManyWithParens::One(w.to_ident(w_span)),
1609+
params: OneOrManyWithParens::One(LambdaFunctionParameter {
1610+
name: w.to_ident(w_span),
1611+
data_type: None,
1612+
}),
1613+
body: Box::new(self.parse_expr()?),
1614+
syntax: LambdaSyntax::Arrow,
1615+
}))
1616+
}
1617+
Token::Word(_)
1618+
if self.peek_nth_token(1).token == Token::Arrow
1619+
&& self.dialect.supports_lambda_functions() =>
1620+
{
1621+
let first_param = LambdaFunctionParameter {
1622+
name: w.to_ident(self.peek_nth_token_ref(1).span),
1623+
data_type: self.maybe_parse(|parser| parser.parse_data_type())?,
1624+
};
1625+
self.expect_token(&Token::Arrow)?;
1626+
Ok(Expr::Lambda(LambdaFunction {
1627+
params: OneOrManyWithParens::One(first_param),
16101628
body: Box::new(self.parse_expr()?),
16111629
syntax: LambdaSyntax::Arrow,
16121630
}))
@@ -2192,7 +2210,12 @@ impl<'a> Parser<'a> {
21922210
return Ok(None);
21932211
}
21942212
self.maybe_parse(|p| {
2195-
let params = p.parse_comma_separated(|p| p.parse_identifier())?;
2213+
let params = p.parse_comma_separated(|p| {
2214+
Ok(LambdaFunctionParameter {
2215+
name: p.parse_identifier()?,
2216+
data_type: None,
2217+
})
2218+
})?;
21962219
p.expect_token(&Token::RParen)?;
21972220
p.expect_token(&Token::Arrow)?;
21982221
let expr = p.parse_expr()?;
@@ -2217,12 +2240,22 @@ impl<'a> Parser<'a> {
22172240
// Parse the parameters: either a single identifier or comma-separated identifiers
22182241
let params = if self.consume_token(&Token::LParen) {
22192242
// Parenthesized parameters: (x, y)
2220-
let params = self.parse_comma_separated(|p| p.parse_identifier())?;
2243+
let params = self.parse_comma_separated(|p| {
2244+
Ok(LambdaFunctionParameter {
2245+
name: p.parse_identifier()?,
2246+
data_type: None,
2247+
})
2248+
})?;
22212249
self.expect_token(&Token::RParen)?;
22222250
OneOrManyWithParens::Many(params)
22232251
} else {
22242252
// Unparenthesized parameters: x or x, y
2225-
let params = self.parse_comma_separated(|p| p.parse_identifier())?;
2253+
let params = self.parse_comma_separated(|p| {
2254+
Ok(LambdaFunctionParameter {
2255+
name: p.parse_identifier()?,
2256+
data_type: None,
2257+
})
2258+
})?;
22262259
if params.len() == 1 {
22272260
OneOrManyWithParens::One(params.into_iter().next().unwrap())
22282261
} else {

tests/sqlparser_common.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15872,7 +15872,16 @@ fn test_lambdas() {
1587215872
]
1587315873
),
1587415874
Expr::Lambda(LambdaFunction {
15875-
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
15875+
params: OneOrManyWithParens::Many(vec![
15876+
LambdaFunctionParameter {
15877+
name: Ident::new("p1"),
15878+
data_type: None
15879+
},
15880+
LambdaFunctionParameter {
15881+
name: Ident::new("p2"),
15882+
data_type: None
15883+
}
15884+
]),
1587615885
body: Box::new(Expr::Case {
1587715886
case_token: AttachedToken::empty(),
1587815887
end_token: AttachedToken::empty(),

tests/sqlparser_databricks.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ fn test_databricks_exists() {
7272
]
7373
),
7474
Expr::Lambda(LambdaFunction {
75-
params: OneOrManyWithParens::One(Ident::new("x")),
75+
params: OneOrManyWithParens::One(LambdaFunctionParameter {
76+
name: Ident::new("x"),
77+
data_type: None
78+
}),
7679
body: Box::new(Expr::IsNull(Box::new(Expr::Identifier(Ident::new("x"))))),
7780
syntax: LambdaSyntax::Arrow,
7881
})
@@ -109,7 +112,16 @@ fn test_databricks_lambdas() {
109112
]
110113
),
111114
Expr::Lambda(LambdaFunction {
112-
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
115+
params: OneOrManyWithParens::Many(vec![
116+
LambdaFunctionParameter {
117+
name: Ident::new("p1"),
118+
data_type: None
119+
},
120+
LambdaFunctionParameter {
121+
name: Ident::new("p2"),
122+
data_type: None
123+
}
124+
]),
113125
body: Box::new(Expr::Case {
114126
case_token: AttachedToken::empty(),
115127
end_token: AttachedToken::empty(),

tests/sqlparser_snowflake.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4557,3 +4557,9 @@ fn test_truncate_table_if_exists() {
45574557
snowflake().verified_stmt("TRUNCATE TABLE my_table");
45584558
snowflake().verified_stmt("TRUNCATE IF EXISTS my_table");
45594559
}
4560+
4561+
#[test]
4562+
fn test_snowflake_lambda() {
4563+
snowflake().verified_expr("TRANSFORM([1, 2, 3], a -> a * 2)");
4564+
snowflake().verified_expr("TRANSFORM([1, 2, 3], a INT -> a * 2)");
4565+
}

0 commit comments

Comments
 (0)