Skip to content

Commit 3a8c468

Browse files
committed
Snowflake: Add support for Lambda functions
1 parent 3ac5670 commit 3a8c468

6 files changed

Lines changed: 93 additions & 8 deletions

File tree

src/ast/mod.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1420,7 +1420,7 @@ impl fmt::Display for AccessExpr {
14201420
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
14211421
pub struct LambdaFunction {
14221422
/// The parameters to the lambda function.
1423-
pub params: OneOrManyWithParens<Ident>,
1423+
pub params: OneOrManyWithParens<LambdaFunctionParameter>,
14241424
/// The body of the lambda function.
14251425
pub body: Box<Expr>,
14261426
/// The syntax style used to write the lambda function.
@@ -1445,6 +1445,26 @@ impl fmt::Display for LambdaFunction {
14451445
}
14461446
}
14471447

1448+
/// A parameter to a lambda function, optionally with a data type.
1449+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
1450+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
1451+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
1452+
pub struct LambdaFunctionParameter {
1453+
/// The name of the parameter
1454+
pub name: Ident,
1455+
/// The optional data type of the parameter
1456+
pub data_type: Option<DataType>,
1457+
}
1458+
1459+
impl fmt::Display for LambdaFunctionParameter {
1460+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1461+
match &self.data_type {
1462+
Some(dt) => write!(f, "{} {}", self.name, dt),
1463+
None => write!(f, "{}", self.name),
1464+
}
1465+
}
1466+
}
1467+
14481468
/// The syntax style for a lambda function.
14491469
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Copy)]
14501470
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]

src/dialect/snowflake.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,11 @@ impl Dialect for SnowflakeDialect {
631631
fn supports_select_wildcard_rename(&self) -> bool {
632632
true
633633
}
634+
635+
/// See <https://docs.snowflake.com/en/user-guide/querying-semistructured#label-higher-order-functions>
636+
fn supports_lambda_functions(&self) -> bool {
637+
true
638+
}
634639
}
635640

636641
// Peeks ahead to identify tokens that are expected after

src/parser/mod.rs

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,7 +1599,25 @@ impl<'a> Parser<'a> {
15991599
Token::Arrow if self.dialect.supports_lambda_functions() => {
16001600
self.expect_token(&Token::Arrow)?;
16011601
Ok(Expr::Lambda(LambdaFunction {
1602-
params: OneOrManyWithParens::One(w.to_ident(w_span)),
1602+
params: OneOrManyWithParens::One(LambdaFunctionParameter {
1603+
name: w.to_ident(w_span),
1604+
data_type: None,
1605+
}),
1606+
body: Box::new(self.parse_expr()?),
1607+
syntax: LambdaSyntax::Arrow,
1608+
}))
1609+
}
1610+
Token::Word(_)
1611+
if self.peek_nth_token(1).token == Token::Arrow
1612+
&& self.dialect.supports_lambda_functions() =>
1613+
{
1614+
let first_param = LambdaFunctionParameter {
1615+
name: w.to_ident(self.peek_nth_token_ref(1).span),
1616+
data_type: self.maybe_parse(|parser| parser.parse_data_type())?,
1617+
};
1618+
self.expect_token(&Token::Arrow)?;
1619+
Ok(Expr::Lambda(LambdaFunction {
1620+
params: OneOrManyWithParens::One(first_param),
16031621
body: Box::new(self.parse_expr()?),
16041622
syntax: LambdaSyntax::Arrow,
16051623
}))
@@ -2185,7 +2203,12 @@ impl<'a> Parser<'a> {
21852203
return Ok(None);
21862204
}
21872205
self.maybe_parse(|p| {
2188-
let params = p.parse_comma_separated(|p| p.parse_identifier())?;
2206+
let params = p.parse_comma_separated(|p| {
2207+
Ok(LambdaFunctionParameter {
2208+
name: p.parse_identifier()?,
2209+
data_type: None,
2210+
})
2211+
})?;
21892212
p.expect_token(&Token::RParen)?;
21902213
p.expect_token(&Token::Arrow)?;
21912214
let expr = p.parse_expr()?;
@@ -2210,12 +2233,22 @@ impl<'a> Parser<'a> {
22102233
// Parse the parameters: either a single identifier or comma-separated identifiers
22112234
let params = if self.consume_token(&Token::LParen) {
22122235
// Parenthesized parameters: (x, y)
2213-
let params = self.parse_comma_separated(|p| p.parse_identifier())?;
2236+
let params = self.parse_comma_separated(|p| {
2237+
Ok(LambdaFunctionParameter {
2238+
name: p.parse_identifier()?,
2239+
data_type: None,
2240+
})
2241+
})?;
22142242
self.expect_token(&Token::RParen)?;
22152243
OneOrManyWithParens::Many(params)
22162244
} else {
22172245
// Unparenthesized parameters: x or x, y
2218-
let params = self.parse_comma_separated(|p| p.parse_identifier())?;
2246+
let params = self.parse_comma_separated(|p| {
2247+
Ok(LambdaFunctionParameter {
2248+
name: p.parse_identifier()?,
2249+
data_type: None,
2250+
})
2251+
})?;
22192252
if params.len() == 1 {
22202253
OneOrManyWithParens::One(params.into_iter().next().unwrap())
22212254
} else {

tests/sqlparser_common.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15721,7 +15721,16 @@ fn test_lambdas() {
1572115721
]
1572215722
),
1572315723
Expr::Lambda(LambdaFunction {
15724-
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
15724+
params: OneOrManyWithParens::Many(vec![
15725+
LambdaFunctionParameter {
15726+
name: Ident::new("p1"),
15727+
data_type: None
15728+
},
15729+
LambdaFunctionParameter {
15730+
name: Ident::new("p2"),
15731+
data_type: None
15732+
}
15733+
]),
1572515734
body: Box::new(Expr::Case {
1572615735
case_token: AttachedToken::empty(),
1572715736
end_token: AttachedToken::empty(),

tests/sqlparser_databricks.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ fn test_databricks_exists() {
7272
]
7373
),
7474
Expr::Lambda(LambdaFunction {
75-
params: OneOrManyWithParens::One(Ident::new("x")),
75+
params: OneOrManyWithParens::One(LambdaFunctionParameter {
76+
name: Ident::new("x"),
77+
data_type: None
78+
}),
7679
body: Box::new(Expr::IsNull(Box::new(Expr::Identifier(Ident::new("x"))))),
7780
syntax: LambdaSyntax::Arrow,
7881
})
@@ -109,7 +112,16 @@ fn test_databricks_lambdas() {
109112
]
110113
),
111114
Expr::Lambda(LambdaFunction {
112-
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
115+
params: OneOrManyWithParens::Many(vec![
116+
LambdaFunctionParameter {
117+
name: Ident::new("p1"),
118+
data_type: None
119+
},
120+
LambdaFunctionParameter {
121+
name: Ident::new("p2"),
122+
data_type: None
123+
}
124+
]),
113125
body: Box::new(Expr::Case {
114126
case_token: AttachedToken::empty(),
115127
end_token: AttachedToken::empty(),

tests/sqlparser_snowflake.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4559,3 +4559,9 @@ fn test_truncate_table_if_exists() {
45594559
snowflake().verified_stmt("TRUNCATE TABLE my_table");
45604560
snowflake().verified_stmt("TRUNCATE IF EXISTS my_table");
45614561
}
4562+
4563+
#[test]
4564+
fn test_snowflake_lambda() {
4565+
snowflake().verified_expr("TRANSFORM([1, 2, 3], a -> a * 2)");
4566+
snowflake().verified_expr("TRANSFORM([1, 2, 3], a INT -> a * 2)");
4567+
}

0 commit comments

Comments
 (0)