Skip to content

Commit 3bf4048

Browse files
authored
Fix parsing of equality binop in function argument (apache#1182)
1 parent e976a2e commit 3bf4048

5 files changed

Lines changed: 64 additions & 18 deletions

File tree

src/dialect/duckdb.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,8 @@ impl Dialect for DuckDbDialect {
3333
fn supports_group_by_expr(&self) -> bool {
3434
true
3535
}
36+
37+
fn supports_named_fn_args_with_eq_operator(&self) -> bool {
38+
true
39+
}
3640
}

src/dialect/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ pub trait Dialect: Debug + Any {
143143
fn supports_start_transaction_modifier(&self) -> bool {
144144
false
145145
}
146+
/// Returns true if the dialect supports named arguments of the form FUN(a = '1', b = '2').
147+
fn supports_named_fn_args_with_eq_operator(&self) -> bool {
148+
false
149+
}
146150
/// Returns true if the dialect has a CONVERT function which accepts a type first
147151
/// and an expression second, e.g. `CONVERT(varchar, 1)`
148152
fn convert_type_before_value(&self) -> bool {

src/parser/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8631,7 +8631,9 @@ impl<'a> Parser<'a> {
86318631
arg,
86328632
operator: FunctionArgOperator::RightArrow,
86338633
})
8634-
} else if self.peek_nth_token(1) == Token::Eq {
8634+
} else if self.dialect.supports_named_fn_args_with_eq_operator()
8635+
&& self.peek_nth_token(1) == Token::Eq
8636+
{
86358637
let name = self.parse_identifier(false)?;
86368638

86378639
self.expect_token(&Token::Eq)?;

src/test_utils.rs

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ impl TestedDialects {
6868
}
6969
Some((dialect, parsed))
7070
})
71-
.unwrap()
71+
.expect("tested dialects cannot be empty")
7272
.1
7373
}
7474

@@ -195,15 +195,6 @@ impl TestedDialects {
195195

196196
/// Returns all available dialects.
197197
pub fn all_dialects() -> TestedDialects {
198-
all_dialects_except(|_| false)
199-
}
200-
201-
/// Returns available dialects. The `except` predicate is used
202-
/// to filter out specific dialects.
203-
pub fn all_dialects_except<F>(except: F) -> TestedDialects
204-
where
205-
F: Fn(&dyn Dialect) -> bool,
206-
{
207198
let all_dialects = vec![
208199
Box::new(GenericDialect {}) as Box<dyn Dialect>,
209200
Box::new(PostgreSqlDialect {}) as Box<dyn Dialect>,
@@ -218,14 +209,30 @@ where
218209
Box::new(DuckDbDialect {}) as Box<dyn Dialect>,
219210
];
220211
TestedDialects {
221-
dialects: all_dialects
222-
.into_iter()
223-
.filter(|d| !except(d.as_ref()))
224-
.collect(),
212+
dialects: all_dialects,
225213
options: None,
226214
}
227215
}
228216

217+
/// Returns all dialects matching the given predicate.
218+
pub fn all_dialects_where<F>(predicate: F) -> TestedDialects
219+
where
220+
F: Fn(&dyn Dialect) -> bool,
221+
{
222+
let mut dialects = all_dialects();
223+
dialects.dialects.retain(|d| predicate(&**d));
224+
dialects
225+
}
226+
227+
/// Returns available dialects. The `except` predicate is used
228+
/// to filter out specific dialects.
229+
pub fn all_dialects_except<F>(except: F) -> TestedDialects
230+
where
231+
F: Fn(&dyn Dialect) -> bool,
232+
{
233+
all_dialects_where(|d| !except(d))
234+
}
235+
229236
pub fn assert_eq_vec<T: ToString>(expected: &[&str], actual: &[T]) {
230237
assert_eq!(
231238
expected,

tests/sqlparser_common.rs

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ use sqlparser::keywords::ALL_KEYWORDS;
3333
use sqlparser::parser::{Parser, ParserError, ParserOptions};
3434
use sqlparser::tokenizer::Tokenizer;
3535
use test_utils::{
36-
all_dialects, alter_table_op, assert_eq_vec, expr_from_projection, join, number, only, table,
37-
table_alias, TestedDialects,
36+
all_dialects, all_dialects_where, alter_table_op, assert_eq_vec, expr_from_projection, join,
37+
number, only, table, table_alias, TestedDialects,
3838
};
3939

4040
#[macro_use]
@@ -4045,7 +4045,9 @@ fn parse_named_argument_function() {
40454045
#[test]
40464046
fn parse_named_argument_function_with_eq_operator() {
40474047
let sql = "SELECT FUN(a = '1', b = '2') FROM foo";
4048-
let select = verified_only_select(sql);
4048+
4049+
let select = all_dialects_where(|d| d.supports_named_fn_args_with_eq_operator())
4050+
.verified_only_select(sql);
40494051
assert_eq!(
40504052
&Expr::Function(Function {
40514053
name: ObjectName(vec![Ident::new("FUN")]),
@@ -4074,6 +4076,33 @@ fn parse_named_argument_function_with_eq_operator() {
40744076
}),
40754077
expr_from_projection(only(&select.projection))
40764078
);
4079+
4080+
// Ensure that bar = 42 in a function argument parses as an equality binop
4081+
// rather than a named function argument.
4082+
assert_eq!(
4083+
all_dialects_except(|d| d.supports_named_fn_args_with_eq_operator())
4084+
.verified_expr("foo(bar = 42)"),
4085+
Expr::Function(Function {
4086+
name: ObjectName(vec![Ident::new("foo")]),
4087+
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
4088+
Expr::BinaryOp {
4089+
left: Box::new(Expr::Identifier(Ident::new("bar"))),
4090+
op: BinaryOperator::Eq,
4091+
right: Box::new(Expr::Value(number("42"))),
4092+
},
4093+
))],
4094+
filter: None,
4095+
null_treatment: None,
4096+
over: None,
4097+
distinct: false,
4098+
special: false,
4099+
order_by: vec![],
4100+
})
4101+
);
4102+
4103+
// TODO: should this parse for all dialects?
4104+
all_dialects_except(|d| d.supports_named_fn_args_with_eq_operator())
4105+
.verified_expr("iff(1 = 1, 1, 0)");
40774106
}
40784107

40794108
#[test]

0 commit comments

Comments
 (0)