Skip to content

Commit 1a45316

Browse files
committed
fix: case expression spans to include leading and trailing keywords
1 parent 5327f0c commit 1a45316

5 files changed

Lines changed: 41 additions & 9 deletions

File tree

src/ast/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,6 +967,8 @@ pub enum Expr {
967967
/// not `< 0` nor `1, 2, 3` as allowed in a `<simple when clause>` per
968968
/// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause>
969969
Case {
970+
case_token: AttachedToken,
971+
end_token: AttachedToken,
970972
operand: Option<Box<Expr>>,
971973
conditions: Vec<CaseWhen>,
972974
else_result: Option<Box<Expr>>,
@@ -1675,6 +1677,8 @@ impl fmt::Display for Expr {
16751677
}
16761678
Expr::Function(fun) => fun.fmt(f),
16771679
Expr::Case {
1680+
case_token: _,
1681+
end_token: _,
16781682
operand,
16791683
conditions,
16801684
else_result,

src/ast/spans.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,18 +1566,24 @@ impl Spanned for Expr {
15661566
),
15671567
Expr::Prefixed { value, .. } => value.span(),
15681568
Expr::Case {
1569+
case_token,
1570+
end_token,
15691571
operand,
15701572
conditions,
15711573
else_result,
15721574
} => union_spans(
1573-
operand
1574-
.as_ref()
1575-
.map(|i| i.span())
1576-
.into_iter()
1577-
.chain(conditions.iter().flat_map(|case_when| {
1578-
[case_when.condition.span(), case_when.result.span()]
1579-
}))
1580-
.chain(else_result.as_ref().map(|i| i.span())),
1575+
iter::once(case_token.0.span)
1576+
.chain(
1577+
operand
1578+
.as_ref()
1579+
.map(|i| i.span())
1580+
.into_iter()
1581+
.chain(conditions.iter().flat_map(|case_when| {
1582+
[case_when.condition.span(), case_when.result.span()]
1583+
}))
1584+
.chain(else_result.as_ref().map(|i| i.span())),
1585+
)
1586+
.chain(iter::once(end_token.0.span)),
15811587
),
15821588
Expr::Exists { subquery, .. } => subquery.span(),
15831589
Expr::Subquery(query) => query.span(),

src/parser/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2274,6 +2274,7 @@ impl<'a> Parser<'a> {
22742274
}
22752275

22762276
pub fn parse_case_expr(&mut self) -> Result<Expr, ParserError> {
2277+
let case_token = AttachedToken(self.get_previous_token().clone());
22772278
let mut operand = None;
22782279
if !self.parse_keyword(Keyword::WHEN) {
22792280
operand = Some(Box::new(self.parse_expr()?));
@@ -2294,8 +2295,10 @@ impl<'a> Parser<'a> {
22942295
} else {
22952296
None
22962297
};
2297-
self.expect_keyword_is(Keyword::END)?;
2298+
let end_token = AttachedToken(self.expect_keyword(Keyword::END)?);
22982299
Ok(Expr::Case {
2300+
case_token,
2301+
end_token,
22992302
operand,
23002303
conditions,
23012304
else_result,

tests/sqlparser_common.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6861,6 +6861,8 @@ fn parse_searched_case_expr() {
68616861
let select = verified_only_select(sql);
68626862
assert_eq!(
68636863
&Case {
6864+
case_token: AttachedToken::empty(),
6865+
end_token: AttachedToken::empty(),
68646866
operand: None,
68656867
conditions: vec![
68666868
CaseWhen {
@@ -6900,6 +6902,8 @@ fn parse_simple_case_expr() {
69006902
use self::Expr::{Case, Identifier};
69016903
assert_eq!(
69026904
&Case {
6905+
case_token: AttachedToken::empty(),
6906+
end_token: AttachedToken::empty(),
69036907
operand: Some(Box::new(Identifier(Ident::new("foo")))),
69046908
conditions: vec![CaseWhen {
69056909
condition: Expr::value(number("1")),
@@ -14464,6 +14468,16 @@ fn test_case_statement_span() {
1446414468
);
1446514469
}
1446614470

14471+
#[test]
14472+
fn test_case_expr_span() {
14473+
let sql = "CASE 1 WHEN 2 THEN 3 ELSE 4 END";
14474+
let mut parser = Parser::new(&GenericDialect {}).try_with_sql(sql).unwrap();
14475+
assert_eq!(
14476+
parser.parse_expr().unwrap().span(),
14477+
Span::new(Location::new(1, 1), Location::new(1, sql.len() as u64 + 1))
14478+
);
14479+
}
14480+
1446714481
#[test]
1446814482
fn parse_if_statement() {
1446914483
let dialects = all_dialects_except(|d| d.is::<MsSqlDialect>());
@@ -14642,6 +14656,8 @@ fn test_lambdas() {
1464214656
Expr::Lambda(LambdaFunction {
1464314657
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
1464414658
body: Box::new(Expr::Case {
14659+
case_token: AttachedToken::empty(),
14660+
end_token: AttachedToken::empty(),
1464514661
operand: None,
1464614662
conditions: vec![
1464714663
CaseWhen {

tests/sqlparser_databricks.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use sqlparser::ast::helpers::attached_token::AttachedToken;
1819
use sqlparser::ast::*;
1920
use sqlparser::dialect::{DatabricksDialect, GenericDialect};
2021
use sqlparser::parser::ParserError;
@@ -108,6 +109,8 @@ fn test_databricks_lambdas() {
108109
Expr::Lambda(LambdaFunction {
109110
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
110111
body: Box::new(Expr::Case {
112+
case_token: AttachedToken::empty(),
113+
end_token: AttachedToken::empty(),
111114
operand: None,
112115
conditions: vec![
113116
CaseWhen {

0 commit comments

Comments
 (0)