Skip to content

Commit eb495f2

Browse files
committed
Add SECURITY DEFINER/INVOKER support for PostgreSQL functions.
1 parent 4954465 commit eb495f2

7 files changed

Lines changed: 82 additions & 9 deletions

File tree

src/ast/ddl.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ use crate::ast::{
4444
ArgMode, AttachedToken, CommentDef, ConditionalStatements, CreateFunctionBody,
4545
CreateFunctionUsing, CreateTableLikeKind, CreateTableOptions, CreateViewParams, DataType, Expr,
4646
FileFormat, FunctionBehavior, FunctionCalledOnNull, FunctionDesc, FunctionDeterminismSpecifier,
47-
FunctionParallel, HiveDistributionStyle, HiveFormat, HiveIOFormat, HiveRowFormat,
48-
HiveSetLocation, Ident, InitializeKind, MySQLColumnPosition, ObjectName, OnCommit,
49-
OneOrManyWithParens, OperateFunctionArg, OrderByExpr, ProjectionSelect, Query, RefreshModeKind,
50-
RowAccessPolicy, SequenceOptions, Spanned, SqlOption, StorageSerializationPolicy, TableVersion,
51-
Tag, TriggerEvent, TriggerExecBody, TriggerObject, TriggerPeriod, TriggerReferencing, Value,
52-
ValueWithSpan, WrappedCollection,
47+
FunctionParallel, FunctionSecurity, HiveDistributionStyle, HiveFormat, HiveIOFormat,
48+
HiveRowFormat, HiveSetLocation, Ident, InitializeKind, MySQLColumnPosition, ObjectName,
49+
OnCommit, OneOrManyWithParens, OperateFunctionArg, OrderByExpr, ProjectionSelect, Query,
50+
RefreshModeKind, RowAccessPolicy, SequenceOptions, Spanned, SqlOption,
51+
StorageSerializationPolicy, TableVersion, Tag, TriggerEvent, TriggerExecBody, TriggerObject,
52+
TriggerPeriod, TriggerReferencing, Value, ValueWithSpan, WrappedCollection,
5353
};
5454
use crate::display_utils::{DisplayCommaSeparated, Indent, NewLine, SpaceOrNewline};
5555
use crate::keywords::Keyword;
@@ -3217,6 +3217,10 @@ pub struct CreateFunction {
32173217
///
32183218
/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html)
32193219
pub parallel: Option<FunctionParallel>,
3220+
/// SECURITY { DEFINER | INVOKER }
3221+
///
3222+
/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html)
3223+
pub security: Option<FunctionSecurity>,
32203224
/// USING ... (Hive only)
32213225
pub using: Option<CreateFunctionUsing>,
32223226
/// Language used in a UDF definition.
@@ -3283,6 +3287,9 @@ impl fmt::Display for CreateFunction {
32833287
if let Some(parallel) = &self.parallel {
32843288
write!(f, " {parallel}")?;
32853289
}
3290+
if let Some(security) = &self.security {
3291+
write!(f, " {security}")?;
3292+
}
32863293
if let Some(remote_connection) = &self.remote_connection {
32873294
write!(f, " REMOTE WITH CONNECTION {remote_connection}")?;
32883295
}

src/ast/mod.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8771,6 +8771,24 @@ impl fmt::Display for FunctionBehavior {
87718771
}
87728772
}
87738773

8774+
/// Specifies whether the function is SECURITY DEFINER or SECURITY INVOKER.
8775+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
8776+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
8777+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
8778+
pub enum FunctionSecurity {
8779+
Definer,
8780+
Invoker,
8781+
}
8782+
8783+
impl fmt::Display for FunctionSecurity {
8784+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
8785+
match self {
8786+
FunctionSecurity::Definer => write!(f, "SECURITY DEFINER"),
8787+
FunctionSecurity::Invoker => write!(f, "SECURITY INVOKER"),
8788+
}
8789+
}
8790+
}
8791+
87748792
/// These attributes describe the behavior of the function when called with a null argument.
87758793
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
87768794
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]

src/ast/spans.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,8 @@ impl Spanned for CreateTable {
553553
cluster_by: _, // todo, BigQuery specific
554554
clustered_by: _, // todo, Hive specific
555555
inherits: _, // todo, PostgreSQL specific
556-
partition_of: _, // todo, PostgreSQL specific
557-
for_values: _, // todo, PostgreSQL specific
556+
partition_of,
557+
for_values,
558558
strict: _, // bool
559559
copy_grants: _, // bool
560560
enable_schema_evolution: _, // bool
@@ -585,7 +585,9 @@ impl Spanned for CreateTable {
585585
.chain(columns.iter().map(|i| i.span()))
586586
.chain(constraints.iter().map(|i| i.span()))
587587
.chain(query.iter().map(|i| i.span()))
588-
.chain(clone.iter().map(|i| i.span())),
588+
.chain(clone.iter().map(|i| i.span()))
589+
.chain(partition_of.iter().map(|i| i.span()))
590+
.chain(for_values.iter().map(|i| i.span())),
589591
)
590592
}
591593
}

src/parser/mod.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5206,6 +5206,7 @@ impl<'a> Parser<'a> {
52065206
function_body: Option<CreateFunctionBody>,
52075207
called_on_null: Option<FunctionCalledOnNull>,
52085208
parallel: Option<FunctionParallel>,
5209+
security: Option<FunctionSecurity>,
52095210
}
52105211
let mut body = Body::default();
52115212
loop {
@@ -5272,6 +5273,15 @@ impl<'a> Parser<'a> {
52725273
} else {
52735274
return self.expected("one of UNSAFE | RESTRICTED | SAFE", self.peek_token());
52745275
}
5276+
} else if self.parse_keyword(Keyword::SECURITY) {
5277+
ensure_not_set(&body.security, "SECURITY { DEFINER | INVOKER }")?;
5278+
if self.parse_keyword(Keyword::DEFINER) {
5279+
body.security = Some(FunctionSecurity::Definer);
5280+
} else if self.parse_keyword(Keyword::INVOKER) {
5281+
body.security = Some(FunctionSecurity::Invoker);
5282+
} else {
5283+
return self.expected("DEFINER or INVOKER", self.peek_token());
5284+
}
52755285
} else if self.parse_keyword(Keyword::RETURN) {
52765286
ensure_not_set(&body.function_body, "RETURN")?;
52775287
body.function_body = Some(CreateFunctionBody::Return(self.parse_expr()?));
@@ -5290,6 +5300,7 @@ impl<'a> Parser<'a> {
52905300
behavior: body.behavior,
52915301
called_on_null: body.called_on_null,
52925302
parallel: body.parallel,
5303+
security: body.security,
52935304
language: body.language,
52945305
function_body: body.function_body,
52955306
if_not_exists: false,
@@ -5327,6 +5338,7 @@ impl<'a> Parser<'a> {
53275338
behavior: None,
53285339
called_on_null: None,
53295340
parallel: None,
5341+
security: None,
53305342
language: None,
53315343
determinism_specifier: None,
53325344
options: None,
@@ -5409,6 +5421,7 @@ impl<'a> Parser<'a> {
54095421
behavior: None,
54105422
called_on_null: None,
54115423
parallel: None,
5424+
security: None,
54125425
}))
54135426
}
54145427

@@ -5498,6 +5511,7 @@ impl<'a> Parser<'a> {
54985511
behavior: None,
54995512
called_on_null: None,
55005513
parallel: None,
5514+
security: None,
55015515
}))
55025516
}
55035517

tests/sqlparser_bigquery.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,6 +2294,7 @@ fn test_bigquery_create_function() {
22942294
remote_connection: None,
22952295
called_on_null: None,
22962296
parallel: None,
2297+
security: None,
22972298
})
22982299
);
22992300

tests/sqlparser_mssql.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ fn parse_create_function() {
266266
behavior: None,
267267
called_on_null: None,
268268
parallel: None,
269+
security: None,
269270
using: None,
270271
language: None,
271272
determinism_specifier: None,
@@ -439,6 +440,7 @@ fn parse_create_function_parameter_default_values() {
439440
behavior: None,
440441
called_on_null: None,
441442
parallel: None,
443+
security: None,
442444
using: None,
443445
language: None,
444446
determinism_specifier: None,

tests/sqlparser_postgres.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4262,6 +4262,7 @@ $$"#;
42624262
behavior: None,
42634263
called_on_null: None,
42644264
parallel: None,
4265+
security: None,
42654266
function_body: Some(CreateFunctionBody::AsBeforeOptions {
42664267
body: Expr::Value(
42674268
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF str1 <> str2 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
@@ -4303,6 +4304,7 @@ $$"#;
43034304
behavior: None,
43044305
called_on_null: None,
43054306
parallel: None,
4307+
security: None,
43064308
function_body: Some(CreateFunctionBody::AsBeforeOptions {
43074309
body: Expr::Value(
43084310
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF int1 <> 0 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
@@ -4348,6 +4350,7 @@ $$"#;
43484350
behavior: None,
43494351
called_on_null: None,
43504352
parallel: None,
4353+
security: None,
43514354
function_body: Some(CreateFunctionBody::AsBeforeOptions {
43524355
body: Expr::Value(
43534356
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF a <> b THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
@@ -4393,6 +4396,7 @@ $$"#;
43934396
behavior: None,
43944397
called_on_null: None,
43954398
parallel: None,
4399+
security: None,
43964400
function_body: Some(CreateFunctionBody::AsBeforeOptions {
43974401
body: Expr::Value(
43984402
(Value::DollarQuotedString(DollarQuotedString {value: "\nBEGIN\n IF int1 <> int2 THEN\n RETURN TRUE;\n ELSE\n RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
@@ -4431,6 +4435,7 @@ $$"#;
44314435
behavior: None,
44324436
called_on_null: None,
44334437
parallel: None,
4438+
security: None,
44344439
function_body: Some(CreateFunctionBody::AsBeforeOptions {
44354440
body: Expr::Value(
44364441
(Value::DollarQuotedString(DollarQuotedString {
@@ -4472,6 +4477,7 @@ fn parse_create_function() {
44724477
behavior: Some(FunctionBehavior::Immutable),
44734478
called_on_null: Some(FunctionCalledOnNull::Strict),
44744479
parallel: Some(FunctionParallel::Safe),
4480+
security: None,
44754481
function_body: Some(CreateFunctionBody::AsBeforeOptions {
44764482
body: Expr::Value(
44774483
(Value::SingleQuotedString("select $1 + $2;".into())).with_empty_span()
@@ -4502,6 +4508,27 @@ fn parse_create_function_detailed() {
45024508
);
45034509
}
45044510

4511+
#[test]
4512+
fn parse_create_function_with_security() {
4513+
let sql =
4514+
"CREATE FUNCTION test_fn() RETURNS void LANGUAGE sql SECURITY DEFINER AS $$ SELECT 1 $$";
4515+
match pg_and_generic().verified_stmt(sql) {
4516+
Statement::CreateFunction(CreateFunction { security, .. }) => {
4517+
assert_eq!(security, Some(FunctionSecurity::Definer));
4518+
}
4519+
_ => panic!("Expected CreateFunction"),
4520+
}
4521+
4522+
let sql2 =
4523+
"CREATE FUNCTION test_fn() RETURNS void LANGUAGE sql SECURITY INVOKER AS $$ SELECT 1 $$";
4524+
match pg_and_generic().verified_stmt(sql2) {
4525+
Statement::CreateFunction(CreateFunction { security, .. }) => {
4526+
assert_eq!(security, Some(FunctionSecurity::Invoker));
4527+
}
4528+
_ => panic!("Expected CreateFunction"),
4529+
}
4530+
}
4531+
45054532
#[test]
45064533
fn parse_incorrect_create_function_parallel() {
45074534
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL PARALLEL BLAH AS 'select $1 + $2;'";
@@ -4530,6 +4557,7 @@ fn parse_create_function_c_with_module_pathname() {
45304557
behavior: Some(FunctionBehavior::Immutable),
45314558
called_on_null: None,
45324559
parallel: Some(FunctionParallel::Safe),
4560+
security: None,
45334561
function_body: Some(CreateFunctionBody::AsBeforeOptions {
45344562
body: Expr::Value(
45354563
(Value::SingleQuotedString("MODULE_PATHNAME".into())).with_empty_span()
@@ -6155,6 +6183,7 @@ fn parse_trigger_related_functions() {
61556183
behavior: None,
61566184
called_on_null: None,
61576185
parallel: None,
6186+
security: None,
61586187
using: None,
61596188
language: Some(Ident::new("plpgsql")),
61606189
determinism_specifier: None,

0 commit comments

Comments
 (0)