Skip to content

Commit a4de535

Browse files
committed
Expand require_semicolon_stmt_delimiter parser option & tests
- a corresponding `supports_statements_without_semicolon_delimiter` Dialect trait function - this is optional for SQL Server, so it's set to `true` for that dialect - for the implementation, `RETURN` parsing needs to be tightened up to avoid ambiguity & tests that formerly asserted "end of statement" now maybe need to assert "an SQL statement" - a new `assert_err_parse_statements` splits the dialects based on semicolon requirements & asserts the expected error message accordingly
1 parent 64f4b1f commit a4de535

7 files changed

Lines changed: 647 additions & 136 deletions

File tree

src/dialect/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1657,6 +1657,11 @@ pub trait Dialect: Debug + Any {
16571657
fn supports_comma_separated_trim(&self) -> bool {
16581658
false
16591659
}
1660+
1661+
/// Returns true if the dialect supports parsing statements without a semicolon delimiter.
1662+
fn supports_statements_without_semicolon_delimiter(&self) -> bool {
1663+
false
1664+
}
16601665
}
16611666

16621667
/// Operators for which precedence must be defined.

src/dialect/mssql.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ impl Dialect for MsSqlDialect {
6666
}
6767

6868
fn supports_connect_by(&self) -> bool {
69-
true
69+
false
7070
}
7171

7272
fn supports_eq_alias_assignment(&self) -> bool {
@@ -122,6 +122,10 @@ impl Dialect for MsSqlDialect {
122122
true
123123
}
124124

125+
fn supports_statements_without_semicolon_delimiter(&self) -> bool {
126+
true
127+
}
128+
125129
/// See <https://learn.microsoft.com/en-us/sql/relational-databases/security/authentication-access/server-level-roles>
126130
fn get_reserved_grantees_types(&self) -> &[GranteesType] {
127131
&[GranteesType::Public]
@@ -378,6 +382,9 @@ impl MsSqlDialect {
378382
) -> Result<Vec<Statement>, ParserError> {
379383
let mut stmts = Vec::new();
380384
loop {
385+
while let Token::SemiColon = parser.peek_token_ref().token {
386+
parser.advance_token();
387+
}
381388
if let Token::EOF = parser.peek_token_ref().token {
382389
break;
383390
}

src/keywords.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[
12151215
Keyword::SEMI,
12161216
Keyword::RETURNING,
12171217
Keyword::OUTPUT,
1218+
Keyword::RETURN,
12181219
Keyword::ASOF,
12191220
Keyword::MATCH_CONDITION,
12201221
// for MSSQL-specific OUTER APPLY (seems reserved in most dialects)
@@ -1270,6 +1271,7 @@ pub const RESERVED_FOR_COLUMN_ALIAS: &[Keyword] = &[
12701271
Keyword::DISTRIBUTE,
12711272
Keyword::RETURNING,
12721273
Keyword::VALUES,
1274+
Keyword::RETURN,
12731275
// Reserved only as a column alias in the `SELECT` clause
12741276
Keyword::FROM,
12751277
Keyword::INTO,
@@ -1284,6 +1286,7 @@ pub const RESERVED_FOR_TABLE_FACTOR: &[Keyword] = &[
12841286
Keyword::LIMIT,
12851287
Keyword::HAVING,
12861288
Keyword::WHERE,
1289+
Keyword::RETURN,
12871290
];
12881291

12891292
/// Global list of reserved keywords that cannot be parsed as identifiers
@@ -1294,4 +1297,5 @@ pub const RESERVED_FOR_IDENTIFIER: &[Keyword] = &[
12941297
Keyword::INTERVAL,
12951298
Keyword::STRUCT,
12961299
Keyword::TRIM,
1300+
Keyword::RETURN,
12971301
];

src/parser/mod.rs

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,22 @@ impl ParserOptions {
288288
self.unescape = unescape;
289289
self
290290
}
291+
292+
/// Set if semicolon statement delimiters are required.
293+
///
294+
/// If this option is `true`, the following SQL will not parse. If the option is `false`, the SQL will parse.
295+
///
296+
/// ```sql
297+
/// SELECT 1
298+
/// SELECT 2
299+
/// ```
300+
pub fn with_require_semicolon_stmt_delimiter(
301+
mut self,
302+
require_semicolon_stmt_delimiter: bool,
303+
) -> Self {
304+
self.require_semicolon_stmt_delimiter = require_semicolon_stmt_delimiter;
305+
self
306+
}
291307
}
292308

293309
#[derive(Copy, Clone)]
@@ -384,7 +400,11 @@ impl<'a> Parser<'a> {
384400
state: ParserState::Normal,
385401
dialect,
386402
recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
387-
options: ParserOptions::new().with_trailing_commas(dialect.supports_trailing_commas()),
403+
options: ParserOptions::new()
404+
.with_trailing_commas(dialect.supports_trailing_commas())
405+
.with_require_semicolon_stmt_delimiter(
406+
!dialect.supports_statements_without_semicolon_delimiter(),
407+
),
388408
}
389409
}
390410

@@ -507,10 +527,10 @@ impl<'a> Parser<'a> {
507527
match &self.peek_token_ref().token {
508528
Token::EOF => break,
509529

510-
// end of statement
511-
Token::Word(word) => {
512-
if expecting_statement_delimiter && word.keyword == Keyword::END {
513-
break;
530+
// don't expect a semicolon statement delimiter after a newline when not otherwise required
531+
Token::Whitespace(Whitespace::Newline) => {
532+
if !self.options.require_semicolon_stmt_delimiter {
533+
expecting_statement_delimiter = false;
514534
}
515535
}
516536
_ => {}
@@ -522,7 +542,7 @@ impl<'a> Parser<'a> {
522542

523543
let statement = self.parse_statement()?;
524544
stmts.push(statement);
525-
expecting_statement_delimiter = true;
545+
expecting_statement_delimiter = self.options.require_semicolon_stmt_delimiter;
526546
}
527547
Ok(stmts)
528548
}
@@ -4987,6 +5007,9 @@ impl<'a> Parser<'a> {
49875007
) -> Result<Vec<Statement>, ParserError> {
49885008
let mut values = vec![];
49895009
loop {
5010+
// ignore empty statements (between successive statement delimiters)
5011+
while self.consume_token(&Token::SemiColon) {}
5012+
49905013
match &self.peek_nth_token_ref(0).token {
49915014
Token::EOF => break,
49925015
Token::Word(w) => {
@@ -4998,7 +5021,13 @@ impl<'a> Parser<'a> {
49985021
}
49995022

50005023
values.push(self.parse_statement()?);
5001-
self.expect_token(&Token::SemiColon)?;
5024+
5025+
if self.options.require_semicolon_stmt_delimiter {
5026+
self.expect_token(&Token::SemiColon)?;
5027+
}
5028+
5029+
// ignore empty statements (between successive statement delimiters)
5030+
while self.consume_token(&Token::SemiColon) {}
50025031
}
50035032
Ok(values)
50045033
}
@@ -19571,7 +19600,28 @@ impl<'a> Parser<'a> {
1957119600

1957219601
/// Parse [Statement::Return]
1957319602
fn parse_return(&mut self) -> Result<Statement, ParserError> {
19574-
match self.maybe_parse(|p| p.parse_expr())? {
19603+
let rs = self.maybe_parse(|p| {
19604+
let expr = p.parse_expr()?;
19605+
19606+
match &expr {
19607+
Expr::Value(_)
19608+
| Expr::Function(_)
19609+
| Expr::UnaryOp { .. }
19610+
| Expr::BinaryOp { .. }
19611+
| Expr::Case { .. }
19612+
| Expr::Cast { .. }
19613+
| Expr::Convert { .. }
19614+
| Expr::Subquery(_) => Ok(expr),
19615+
// todo: how to retstrict to variables?
19616+
Expr::Identifier(id) if id.value.starts_with('@') => Ok(expr),
19617+
_ => parser_err!(
19618+
"Non-returnable expression found following RETURN",
19619+
p.peek_token().span.start
19620+
),
19621+
}
19622+
})?;
19623+
19624+
match rs {
1957519625
Some(expr) => Ok(Statement::Return(ReturnStatement {
1957619626
value: Some(ReturnStatementValue::Expr(expr)),
1957719627
})),

src/test_utils.rs

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#[cfg(not(feature = "std"))]
2626
use alloc::{
2727
boxed::Box,
28+
format,
2829
string::{String, ToString},
2930
vec,
3031
vec::Vec,
@@ -186,6 +187,37 @@ impl TestedDialects {
186187
statements
187188
}
188189

190+
/// The same as [`statements_parse_to`] but it will strip semicolons from the SQL text.
191+
pub fn statements_without_semicolons_parse_to(
192+
&self,
193+
sql: &str,
194+
canonical: &str,
195+
) -> Vec<Statement> {
196+
let sql_without_semicolons = sql
197+
.replace("; ", " ")
198+
.replace(" ;", " ")
199+
.replace(";\n", "\n")
200+
.replace("\n;", "\n")
201+
.replace(";", " ");
202+
let statements = self
203+
.parse_sql_statements(&sql_without_semicolons)
204+
.expect(&sql_without_semicolons);
205+
if !canonical.is_empty() && sql != canonical {
206+
assert_eq!(self.parse_sql_statements(canonical).unwrap(), statements);
207+
} else {
208+
assert_eq!(
209+
sql,
210+
statements
211+
.iter()
212+
// note: account for format_statement_list manually inserted semicolons
213+
.map(|s| s.to_string().trim_end_matches(";").to_string())
214+
.collect::<Vec<_>>()
215+
.join("; ")
216+
);
217+
}
218+
statements
219+
}
220+
189221
/// Ensures that `sql` parses as an [`Expr`], and that
190222
/// re-serializing the parse result produces canonical
191223
pub fn expr_parses_to(&self, sql: &str, canonical: &str) -> Expr {
@@ -319,6 +351,43 @@ where
319351
all_dialects_where(|d| !except(d))
320352
}
321353

354+
/// Returns all dialects that don't support statements without semicolon delimiters.
355+
/// (i.e. dialects that require semicolon delimiters.)
356+
pub fn all_dialects_requiring_semicolon_statement_delimiter() -> TestedDialects {
357+
let tested_dialects =
358+
all_dialects_except(|d| d.supports_statements_without_semicolon_delimiter());
359+
assert_ne!(tested_dialects.dialects.len(), 0);
360+
tested_dialects
361+
}
362+
363+
/// Returns all dialects that do support statements without semicolon delimiters.
364+
/// (i.e. dialects not requiring semicolon delimiters.)
365+
pub fn all_dialects_not_requiring_semicolon_statement_delimiter() -> TestedDialects {
366+
let tested_dialects =
367+
all_dialects_where(|d| d.supports_statements_without_semicolon_delimiter());
368+
assert_ne!(tested_dialects.dialects.len(), 0);
369+
tested_dialects
370+
}
371+
372+
/// Asserts an error for `parse_sql_statements`:
373+
/// - "end of statement" for dialects that require semicolon delimiters
374+
/// - "an SQL statement" for dialects that don't require semicolon delimiters.
375+
pub fn assert_err_parse_statements(sql: &str, found: &str) {
376+
assert_eq!(
377+
ParserError::ParserError(format!("Expected: end of statement, found: {found}")),
378+
all_dialects_requiring_semicolon_statement_delimiter()
379+
.parse_sql_statements(sql)
380+
.unwrap_err()
381+
);
382+
383+
assert_eq!(
384+
ParserError::ParserError(format!("Expected: an SQL statement, found: {found}")),
385+
all_dialects_not_requiring_semicolon_statement_delimiter()
386+
.parse_sql_statements(sql)
387+
.unwrap_err()
388+
);
389+
}
390+
322391
pub fn assert_eq_vec<T: ToString>(expected: &[&str], actual: &[T]) {
323392
assert_eq!(
324393
expected,

0 commit comments

Comments
 (0)