Skip to content

Commit 157a992

Browse files
Resolved bug in parse_function_arg
1 parent 7703fd0 commit 157a992

2 files changed

Lines changed: 183 additions & 5 deletions

File tree

src/parser/mod.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5199,12 +5199,22 @@ impl<'a> Parser<'a> {
51995199

52005200
// parse: [ argname ] argtype
52015201
let mut name = None;
5202+
let next_token = self.peek_token();
52025203
let mut data_type = self.parse_data_type()?;
5203-
if let DataType::Custom(n, _) = &data_type {
5204-
// the first token is actually a name
5205-
match n.0[0].clone() {
5206-
ObjectNamePart::Identifier(ident) => name = Some(ident),
5207-
}
5204+
5205+
// It may appear that the first token can be converted into a known
5206+
// type, but this could also be a collision as some types are only
5207+
// present in some dialects and therefore some type reserved keywords
5208+
// may be freely used as argument names in other dialects.
5209+
5210+
// To check whether the first token is a name or a type, we need to
5211+
// peek the next token, which if it is another type keyword, then the
5212+
// first token is a name and not a type in itself.
5213+
let potential_tokens = [Token::Eq, Token::RParen, Token::Comma];
5214+
if !self.peek_keyword(Keyword::DEFAULT)
5215+
&& !potential_tokens.contains(&self.peek_token().token)
5216+
{
5217+
name = Some(Ident::new(next_token.to_string()));
52085218
data_type = self.parse_data_type()?;
52095219
}
52105220

tests/sqlparser_postgres.rs

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4098,6 +4098,174 @@ fn parse_update_in_with_subquery() {
40984098
pg_and_generic().verified_stmt(r#"WITH "result" AS (UPDATE "Hero" SET "name" = 'Captain America', "number_of_movies" = "number_of_movies" + 1 WHERE "secret_identity" = 'Sam Wilson' RETURNING "id", "name", "secret_identity", "number_of_movies") SELECT * FROM "result""#);
40994099
}
41004100

4101+
#[test]
4102+
fn parser_create_function_with_args() {
4103+
let sql1 = r#"CREATE OR REPLACE FUNCTION check_strings_different(str1 VARCHAR, str2 VARCHAR) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
4104+
BEGIN
4105+
IF str1 <> str2 THEN
4106+
RETURN TRUE;
4107+
ELSE
4108+
RETURN FALSE;
4109+
END IF;
4110+
END;
4111+
$$"#;
4112+
4113+
assert_eq!(
4114+
pg_and_generic().verified_stmt(sql1),
4115+
Statement::CreateFunction(CreateFunction {
4116+
or_alter: false,
4117+
or_replace: true,
4118+
temporary: false,
4119+
name: ObjectName::from(vec![Ident::new("check_strings_different")]),
4120+
args: Some(vec![
4121+
OperateFunctionArg::with_name(
4122+
"str1",
4123+
DataType::Varchar(None),
4124+
),
4125+
OperateFunctionArg::with_name(
4126+
"str2",
4127+
DataType::Varchar(None),
4128+
),
4129+
]),
4130+
return_type: Some(DataType::Boolean),
4131+
language: Some("plpgsql".into()),
4132+
behavior: None,
4133+
called_on_null: None,
4134+
parallel: None,
4135+
function_body: Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
4136+
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF str1 <> str2 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
4137+
))),
4138+
if_not_exists: false,
4139+
using: None,
4140+
determinism_specifier: None,
4141+
options: None,
4142+
remote_connection: None,
4143+
})
4144+
);
4145+
4146+
let sql2 = r#"CREATE OR REPLACE FUNCTION check_not_zero(int1 INT) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
4147+
BEGIN
4148+
IF int1 <> 0 THEN
4149+
RETURN TRUE;
4150+
ELSE
4151+
RETURN FALSE;
4152+
END IF;
4153+
END;
4154+
$$"#;
4155+
assert_eq!(
4156+
pg_and_generic().verified_stmt(sql2),
4157+
Statement::CreateFunction(CreateFunction {
4158+
or_alter: false,
4159+
or_replace: true,
4160+
temporary: false,
4161+
name: ObjectName::from(vec![Ident::new("check_not_zero")]),
4162+
args: Some(vec![
4163+
OperateFunctionArg::with_name(
4164+
"int1",
4165+
DataType::Int(None)
4166+
)
4167+
]),
4168+
return_type: Some(DataType::Boolean),
4169+
language: Some("plpgsql".into()),
4170+
behavior: None,
4171+
called_on_null: None,
4172+
parallel: None,
4173+
function_body: Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
4174+
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF int1 <> 0 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
4175+
))),
4176+
if_not_exists: false,
4177+
using: None,
4178+
determinism_specifier: None,
4179+
options: None,
4180+
remote_connection: None,
4181+
})
4182+
);
4183+
4184+
let sql3 = r#"CREATE OR REPLACE FUNCTION check_values_different(a INT, b INT) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
4185+
BEGIN
4186+
IF a <> b THEN
4187+
RETURN TRUE;
4188+
ELSE
4189+
RETURN FALSE;
4190+
END IF;
4191+
END;
4192+
$$"#;
4193+
assert_eq!(
4194+
pg_and_generic().verified_stmt(sql3),
4195+
Statement::CreateFunction(CreateFunction {
4196+
or_alter: false,
4197+
or_replace: true,
4198+
temporary: false,
4199+
name: ObjectName::from(vec![Ident::new("check_values_different")]),
4200+
args: Some(vec![
4201+
OperateFunctionArg::with_name(
4202+
"a",
4203+
DataType::Int(None)
4204+
),
4205+
OperateFunctionArg::with_name(
4206+
"b",
4207+
DataType::Int(None)
4208+
),
4209+
]),
4210+
return_type: Some(DataType::Boolean),
4211+
language: Some("plpgsql".into()),
4212+
behavior: None,
4213+
called_on_null: None,
4214+
parallel: None,
4215+
function_body: Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
4216+
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF a <> b THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
4217+
))),
4218+
if_not_exists: false,
4219+
using: None,
4220+
determinism_specifier: None,
4221+
options: None,
4222+
remote_connection: None,
4223+
})
4224+
);
4225+
4226+
let sql4 = r#"CREATE OR REPLACE FUNCTION check_values_different(int1 INT, int2 INT) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
4227+
BEGIN
4228+
IF int1 <> int2 THEN
4229+
RETURN TRUE;
4230+
ELSE
4231+
RETURN FALSE;
4232+
END IF;
4233+
END;
4234+
$$"#;
4235+
assert_eq!(
4236+
pg_and_generic().verified_stmt(sql4),
4237+
Statement::CreateFunction(CreateFunction {
4238+
or_alter: false,
4239+
or_replace: true,
4240+
temporary: false,
4241+
name: ObjectName::from(vec![Ident::new("check_values_different")]),
4242+
args: Some(vec![
4243+
OperateFunctionArg::with_name(
4244+
"int1",
4245+
DataType::Int(None)
4246+
),
4247+
OperateFunctionArg::with_name(
4248+
"int2",
4249+
DataType::Int(None)
4250+
),
4251+
]),
4252+
return_type: Some(DataType::Boolean),
4253+
language: Some("plpgsql".into()),
4254+
behavior: None,
4255+
called_on_null: None,
4256+
parallel: None,
4257+
function_body: Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
4258+
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF int1 <> int2 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
4259+
))),
4260+
if_not_exists: false,
4261+
using: None,
4262+
determinism_specifier: None,
4263+
options: None,
4264+
remote_connection: None,
4265+
})
4266+
);
4267+
}
4268+
41014269
#[test]
41024270
fn parse_create_function() {
41034271
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE STRICT PARALLEL SAFE AS 'select $1 + $2;'";

0 commit comments

Comments
 (0)