Skip to content

Commit 5f90cef

Browse files
committed
Snowflake: Add support for Lambda functions
1 parent 8e36e8e commit 5f90cef

6 files changed

Lines changed: 93 additions & 8 deletions

File tree

src/ast/mod.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1423,7 +1423,7 @@ impl fmt::Display for AccessExpr {
14231423
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
14241424
pub struct LambdaFunction {
14251425
/// The parameters to the lambda function.
1426-
pub params: OneOrManyWithParens<Ident>,
1426+
pub params: OneOrManyWithParens<LambdaFunctionParameter>,
14271427
/// The body of the lambda function.
14281428
pub body: Box<Expr>,
14291429
/// The syntax style used to write the lambda function.
@@ -1448,6 +1448,26 @@ impl fmt::Display for LambdaFunction {
14481448
}
14491449
}
14501450

1451+
/// A parameter to a lambda function, optionally with a data type.
1452+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
1453+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
1454+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
1455+
pub struct LambdaFunctionParameter {
1456+
/// The name of the parameter
1457+
pub name: Ident,
1458+
/// The optional data type of the parameter
1459+
pub data_type: Option<DataType>,
1460+
}
1461+
1462+
impl fmt::Display for LambdaFunctionParameter {
1463+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1464+
match &self.data_type {
1465+
Some(dt) => write!(f, "{} {}", self.name, dt),
1466+
None => write!(f, "{}", self.name),
1467+
}
1468+
}
1469+
}
1470+
14511471
/// The syntax style for a lambda function.
14521472
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Copy)]
14531473
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]

src/dialect/snowflake.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,11 @@ impl Dialect for SnowflakeDialect {
662662
fn supports_select_wildcard_rename(&self) -> bool {
663663
true
664664
}
665+
666+
/// See <https://docs.snowflake.com/en/user-guide/querying-semistructured#label-higher-order-functions>
667+
fn supports_lambda_functions(&self) -> bool {
668+
true
669+
}
665670
}
666671

667672
// Peeks ahead to identify tokens that are expected after

src/parser/mod.rs

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,7 +1609,25 @@ impl<'a> Parser<'a> {
16091609
Token::Arrow if self.dialect.supports_lambda_functions() => {
16101610
self.expect_token(&Token::Arrow)?;
16111611
Ok(Expr::Lambda(LambdaFunction {
1612-
params: OneOrManyWithParens::One(w.to_ident(w_span)),
1612+
params: OneOrManyWithParens::One(LambdaFunctionParameter {
1613+
name: w.to_ident(w_span),
1614+
data_type: None,
1615+
}),
1616+
body: Box::new(self.parse_expr()?),
1617+
syntax: LambdaSyntax::Arrow,
1618+
}))
1619+
}
1620+
Token::Word(_)
1621+
if self.peek_nth_token(1).token == Token::Arrow
1622+
&& self.dialect.supports_lambda_functions() =>
1623+
{
1624+
let first_param = LambdaFunctionParameter {
1625+
name: w.to_ident(self.peek_nth_token_ref(1).span),
1626+
data_type: self.maybe_parse(|parser| parser.parse_data_type())?,
1627+
};
1628+
self.expect_token(&Token::Arrow)?;
1629+
Ok(Expr::Lambda(LambdaFunction {
1630+
params: OneOrManyWithParens::One(first_param),
16131631
body: Box::new(self.parse_expr()?),
16141632
syntax: LambdaSyntax::Arrow,
16151633
}))
@@ -2195,7 +2213,12 @@ impl<'a> Parser<'a> {
21952213
return Ok(None);
21962214
}
21972215
self.maybe_parse(|p| {
2198-
let params = p.parse_comma_separated(|p| p.parse_identifier())?;
2216+
let params = p.parse_comma_separated(|p| {
2217+
Ok(LambdaFunctionParameter {
2218+
name: p.parse_identifier()?,
2219+
data_type: None,
2220+
})
2221+
})?;
21992222
p.expect_token(&Token::RParen)?;
22002223
p.expect_token(&Token::Arrow)?;
22012224
let expr = p.parse_expr()?;
@@ -2220,12 +2243,22 @@ impl<'a> Parser<'a> {
22202243
// Parse the parameters: either a single identifier or comma-separated identifiers
22212244
let params = if self.consume_token(&Token::LParen) {
22222245
// Parenthesized parameters: (x, y)
2223-
let params = self.parse_comma_separated(|p| p.parse_identifier())?;
2246+
let params = self.parse_comma_separated(|p| {
2247+
Ok(LambdaFunctionParameter {
2248+
name: p.parse_identifier()?,
2249+
data_type: None,
2250+
})
2251+
})?;
22242252
self.expect_token(&Token::RParen)?;
22252253
OneOrManyWithParens::Many(params)
22262254
} else {
22272255
// Unparenthesized parameters: x or x, y
2228-
let params = self.parse_comma_separated(|p| p.parse_identifier())?;
2256+
let params = self.parse_comma_separated(|p| {
2257+
Ok(LambdaFunctionParameter {
2258+
name: p.parse_identifier()?,
2259+
data_type: None,
2260+
})
2261+
})?;
22292262
if params.len() == 1 {
22302263
OneOrManyWithParens::One(params.into_iter().next().unwrap())
22312264
} else {

tests/sqlparser_common.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15925,7 +15925,16 @@ fn test_lambdas() {
1592515925
]
1592615926
),
1592715927
Expr::Lambda(LambdaFunction {
15928-
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
15928+
params: OneOrManyWithParens::Many(vec![
15929+
LambdaFunctionParameter {
15930+
name: Ident::new("p1"),
15931+
data_type: None
15932+
},
15933+
LambdaFunctionParameter {
15934+
name: Ident::new("p2"),
15935+
data_type: None
15936+
}
15937+
]),
1592915938
body: Box::new(Expr::Case {
1593015939
case_token: AttachedToken::empty(),
1593115940
end_token: AttachedToken::empty(),

tests/sqlparser_databricks.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ fn test_databricks_exists() {
7272
]
7373
),
7474
Expr::Lambda(LambdaFunction {
75-
params: OneOrManyWithParens::One(Ident::new("x")),
75+
params: OneOrManyWithParens::One(LambdaFunctionParameter {
76+
name: Ident::new("x"),
77+
data_type: None
78+
}),
7679
body: Box::new(Expr::IsNull(Box::new(Expr::Identifier(Ident::new("x"))))),
7780
syntax: LambdaSyntax::Arrow,
7881
})
@@ -109,7 +112,16 @@ fn test_databricks_lambdas() {
109112
]
110113
),
111114
Expr::Lambda(LambdaFunction {
112-
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
115+
params: OneOrManyWithParens::Many(vec![
116+
LambdaFunctionParameter {
117+
name: Ident::new("p1"),
118+
data_type: None
119+
},
120+
LambdaFunctionParameter {
121+
name: Ident::new("p2"),
122+
data_type: None
123+
}
124+
]),
113125
body: Box::new(Expr::Case {
114126
case_token: AttachedToken::empty(),
115127
end_token: AttachedToken::empty(),

tests/sqlparser_snowflake.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4878,3 +4878,9 @@ fn test_truncate_table_if_exists() {
48784878
snowflake().verified_stmt("TRUNCATE TABLE my_table");
48794879
snowflake().verified_stmt("TRUNCATE IF EXISTS my_table");
48804880
}
4881+
4882+
#[test]
4883+
fn test_snowflake_lambda() {
4884+
snowflake().verified_expr("TRANSFORM([1, 2, 3], a -> a * 2)");
4885+
snowflake().verified_expr("TRANSFORM([1, 2, 3], a INT -> a * 2)");
4886+
}

0 commit comments

Comments
 (0)