Skip to content

Commit 4094fc8

Browse files
committed
Add SET configuration_parameter support for PostgreSQL functions.
1 parent eb495f2 commit 4094fc8

6 files changed

Lines changed: 113 additions & 7 deletions

File tree

src/ast/ddl.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,14 @@ use crate::ast::{
4343
},
4444
ArgMode, AttachedToken, CommentDef, ConditionalStatements, CreateFunctionBody,
4545
CreateFunctionUsing, CreateTableLikeKind, CreateTableOptions, CreateViewParams, DataType, Expr,
46-
FileFormat, FunctionBehavior, FunctionCalledOnNull, FunctionDesc, FunctionDeterminismSpecifier,
47-
FunctionParallel, FunctionSecurity, HiveDistributionStyle, HiveFormat, HiveIOFormat,
48-
HiveRowFormat, HiveSetLocation, Ident, InitializeKind, MySQLColumnPosition, ObjectName,
49-
OnCommit, OneOrManyWithParens, OperateFunctionArg, OrderByExpr, ProjectionSelect, Query,
50-
RefreshModeKind, RowAccessPolicy, SequenceOptions, Spanned, SqlOption,
51-
StorageSerializationPolicy, TableVersion, Tag, TriggerEvent, TriggerExecBody, TriggerObject,
52-
TriggerPeriod, TriggerReferencing, Value, ValueWithSpan, WrappedCollection,
46+
FileFormat, FunctionBehavior, FunctionCalledOnNull, FunctionDefinitionSetParam, FunctionDesc,
47+
FunctionDeterminismSpecifier, FunctionParallel, FunctionSecurity, HiveDistributionStyle,
48+
HiveFormat, HiveIOFormat, HiveRowFormat, HiveSetLocation, Ident, InitializeKind,
49+
MySQLColumnPosition, ObjectName, OnCommit, OneOrManyWithParens, OperateFunctionArg,
50+
OrderByExpr, ProjectionSelect, Query, RefreshModeKind, RowAccessPolicy, SequenceOptions,
51+
Spanned, SqlOption, StorageSerializationPolicy, TableVersion, Tag, TriggerEvent,
52+
TriggerExecBody, TriggerObject, TriggerPeriod, TriggerReferencing, Value, ValueWithSpan,
53+
WrappedCollection,
5354
};
5455
use crate::display_utils::{DisplayCommaSeparated, Indent, NewLine, SpaceOrNewline};
5556
use crate::keywords::Keyword;
@@ -3221,6 +3222,10 @@ pub struct CreateFunction {
32213222
///
32223223
/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html)
32233224
pub security: Option<FunctionSecurity>,
3225+
/// SET configuration_parameter clauses
3226+
///
3227+
/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html)
3228+
pub set_params: Vec<FunctionDefinitionSetParam>,
32243229
/// USING ... (Hive only)
32253230
pub using: Option<CreateFunctionUsing>,
32263231
/// Language used in a UDF definition.
@@ -3290,6 +3295,9 @@ impl fmt::Display for CreateFunction {
32903295
if let Some(security) = &self.security {
32913296
write!(f, " {security}")?;
32923297
}
3298+
for set_param in &self.set_params {
3299+
write!(f, " {set_param}")?;
3300+
}
32933301
if let Some(remote_connection) = &self.remote_connection {
32943302
write!(f, " REMOTE WITH CONNECTION {remote_connection}")?;
32953303
}

src/ast/mod.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8789,6 +8789,42 @@ impl fmt::Display for FunctionSecurity {
87898789
}
87908790
}
87918791

8792+
/// Value for a SET configuration parameter in a CREATE FUNCTION statement.
8793+
///
8794+
/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html)
8795+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
8796+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
8797+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
8798+
pub enum FunctionSetValue {
8799+
/// SET param = value1, value2, ...
8800+
Values(Vec<Expr>),
8801+
/// SET param FROM CURRENT
8802+
FromCurrent,
8803+
}
8804+
8805+
/// A SET configuration_parameter clause in a CREATE FUNCTION statement.
8806+
///
8807+
/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html)
8808+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
8809+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
8810+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
8811+
pub struct FunctionDefinitionSetParam {
8812+
pub name: Ident,
8813+
pub value: FunctionSetValue,
8814+
}
8815+
8816+
impl fmt::Display for FunctionDefinitionSetParam {
8817+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
8818+
write!(f, "SET {} ", self.name)?;
8819+
match &self.value {
8820+
FunctionSetValue::Values(values) => {
8821+
write!(f, "= {}", display_comma_separated(values))
8822+
}
8823+
FunctionSetValue::FromCurrent => write!(f, "FROM CURRENT"),
8824+
}
8825+
}
8826+
}
8827+
87928828
/// These attributes describe the behavior of the function when called with a null argument.
87938829
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
87948830
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]

src/parser/mod.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5209,6 +5209,7 @@ impl<'a> Parser<'a> {
52095209
security: Option<FunctionSecurity>,
52105210
}
52115211
let mut body = Body::default();
5212+
let mut set_params: Vec<FunctionDefinitionSetParam> = Vec::new();
52125213
loop {
52135214
fn ensure_not_set<T>(field: &Option<T>, name: &str) -> Result<(), ParserError> {
52145215
if field.is_some() {
@@ -5282,6 +5283,18 @@ impl<'a> Parser<'a> {
52825283
} else {
52835284
return self.expected("DEFINER or INVOKER", self.peek_token());
52845285
}
5286+
} else if self.parse_keyword(Keyword::SET) {
5287+
let name = self.parse_identifier()?;
5288+
let value = if self.parse_keywords(&[Keyword::FROM, Keyword::CURRENT]) {
5289+
FunctionSetValue::FromCurrent
5290+
} else {
5291+
if !self.consume_token(&Token::Eq) && !self.parse_keyword(Keyword::TO) {
5292+
return self.expected("= or TO", self.peek_token());
5293+
}
5294+
let values = self.parse_comma_separated(Parser::parse_expr)?;
5295+
FunctionSetValue::Values(values)
5296+
};
5297+
set_params.push(FunctionDefinitionSetParam { name, value });
52855298
} else if self.parse_keyword(Keyword::RETURN) {
52865299
ensure_not_set(&body.function_body, "RETURN")?;
52875300
body.function_body = Some(CreateFunctionBody::Return(self.parse_expr()?));
@@ -5301,6 +5314,7 @@ impl<'a> Parser<'a> {
53015314
called_on_null: body.called_on_null,
53025315
parallel: body.parallel,
53035316
security: body.security,
5317+
set_params,
53045318
language: body.language,
53055319
function_body: body.function_body,
53065320
if_not_exists: false,
@@ -5339,6 +5353,7 @@ impl<'a> Parser<'a> {
53395353
called_on_null: None,
53405354
parallel: None,
53415355
security: None,
5356+
set_params: vec![],
53425357
language: None,
53435358
determinism_specifier: None,
53445359
options: None,
@@ -5422,6 +5437,7 @@ impl<'a> Parser<'a> {
54225437
called_on_null: None,
54235438
parallel: None,
54245439
security: None,
5440+
set_params: vec![],
54255441
}))
54265442
}
54275443

@@ -5512,6 +5528,7 @@ impl<'a> Parser<'a> {
55125528
called_on_null: None,
55135529
parallel: None,
55145530
security: None,
5531+
set_params: vec![],
55155532
}))
55165533
}
55175534

tests/sqlparser_bigquery.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2295,6 +2295,7 @@ fn test_bigquery_create_function() {
22952295
called_on_null: None,
22962296
parallel: None,
22972297
security: None,
2298+
set_params: vec![],
22982299
})
22992300
);
23002301

tests/sqlparser_mssql.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ fn parse_create_function() {
267267
called_on_null: None,
268268
parallel: None,
269269
security: None,
270+
set_params: vec![],
270271
using: None,
271272
language: None,
272273
determinism_specifier: None,
@@ -441,6 +442,7 @@ fn parse_create_function_parameter_default_values() {
441442
called_on_null: None,
442443
parallel: None,
443444
security: None,
445+
set_params: vec![],
444446
using: None,
445447
language: None,
446448
determinism_specifier: None,

tests/sqlparser_postgres.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4263,6 +4263,7 @@ $$"#;
42634263
called_on_null: None,
42644264
parallel: None,
42654265
security: None,
4266+
set_params: vec![],
42664267
function_body: Some(CreateFunctionBody::AsBeforeOptions {
42674268
body: Expr::Value(
42684269
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF str1 <> str2 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
@@ -4305,6 +4306,7 @@ $$"#;
43054306
called_on_null: None,
43064307
parallel: None,
43074308
security: None,
4309+
set_params: vec![],
43084310
function_body: Some(CreateFunctionBody::AsBeforeOptions {
43094311
body: Expr::Value(
43104312
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF int1 <> 0 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
@@ -4351,6 +4353,7 @@ $$"#;
43514353
called_on_null: None,
43524354
parallel: None,
43534355
security: None,
4356+
set_params: vec![],
43544357
function_body: Some(CreateFunctionBody::AsBeforeOptions {
43554358
body: Expr::Value(
43564359
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF a <> b THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
@@ -4397,6 +4400,7 @@ $$"#;
43974400
called_on_null: None,
43984401
parallel: None,
43994402
security: None,
4403+
set_params: vec![],
44004404
function_body: Some(CreateFunctionBody::AsBeforeOptions {
44014405
body: Expr::Value(
44024406
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF int1 <> int2 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
@@ -4436,6 +4440,7 @@ $$"#;
44364440
called_on_null: None,
44374441
parallel: None,
44384442
security: None,
4443+
set_params: vec![],
44394444
function_body: Some(CreateFunctionBody::AsBeforeOptions {
44404445
body: Expr::Value(
44414446
(Value::DollarQuotedString(DollarQuotedString {
@@ -4478,6 +4483,7 @@ fn parse_create_function() {
44784483
called_on_null: Some(FunctionCalledOnNull::Strict),
44794484
parallel: Some(FunctionParallel::Safe),
44804485
security: None,
4486+
set_params: vec![],
44814487
function_body: Some(CreateFunctionBody::AsBeforeOptions {
44824488
body: Expr::Value(
44834489
(Value::SingleQuotedString("select $1 + $2;".into())).with_empty_span()
@@ -4529,6 +4535,40 @@ fn parse_create_function_with_security() {
45294535
}
45304536
}
45314537

4538+
#[test]
4539+
fn parse_create_function_with_set_params() {
4540+
let sql =
4541+
"CREATE FUNCTION test_fn() RETURNS void LANGUAGE sql SET search_path = auth, pg_temp, public AS $$ SELECT 1 $$";
4542+
match pg_and_generic().verified_stmt(sql) {
4543+
Statement::CreateFunction(CreateFunction { set_params, .. }) => {
4544+
assert_eq!(set_params.len(), 1);
4545+
assert_eq!(set_params[0].name.to_string(), "search_path");
4546+
}
4547+
_ => panic!("Expected CreateFunction"),
4548+
}
4549+
4550+
// Test multiple SET params
4551+
let sql2 =
4552+
"CREATE FUNCTION test_fn() RETURNS void LANGUAGE sql SET search_path = public SET statement_timeout = '5s' AS $$ SELECT 1 $$";
4553+
match pg_and_generic().verified_stmt(sql2) {
4554+
Statement::CreateFunction(CreateFunction { set_params, .. }) => {
4555+
assert_eq!(set_params.len(), 2);
4556+
}
4557+
_ => panic!("Expected CreateFunction"),
4558+
}
4559+
4560+
// Test FROM CURRENT
4561+
let sql3 =
4562+
"CREATE FUNCTION test_fn() RETURNS void LANGUAGE sql SET search_path FROM CURRENT AS $$ SELECT 1 $$";
4563+
match pg_and_generic().verified_stmt(sql3) {
4564+
Statement::CreateFunction(CreateFunction { set_params, .. }) => {
4565+
assert_eq!(set_params.len(), 1);
4566+
assert!(matches!(set_params[0].value, FunctionSetValue::FromCurrent));
4567+
}
4568+
_ => panic!("Expected CreateFunction"),
4569+
}
4570+
}
4571+
45324572
#[test]
45334573
fn parse_incorrect_create_function_parallel() {
45344574
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL PARALLEL BLAH AS 'select $1 + $2;'";
@@ -4558,6 +4598,7 @@ fn parse_create_function_c_with_module_pathname() {
45584598
called_on_null: None,
45594599
parallel: Some(FunctionParallel::Safe),
45604600
security: None,
4601+
set_params: vec![],
45614602
function_body: Some(CreateFunctionBody::AsBeforeOptions {
45624603
body: Expr::Value(
45634604
(Value::SingleQuotedString("MODULE_PATHNAME".into())).with_empty_span()
@@ -6184,6 +6225,7 @@ fn parse_trigger_related_functions() {
61846225
called_on_null: None,
61856226
parallel: None,
61866227
security: None,
6228+
set_params: vec![],
61876229
using: None,
61886230
language: Some(Ident::new("plpgsql")),
61896231
determinism_specifier: None,

0 commit comments

Comments
 (0)