Skip to content

Commit 2bf93a4

Browse files
authored
Support PARALLEL ... and for ..ON NULL INPUT ... to CREATE FUNCTION` (apache#1202)
1 parent 14b33ac commit 2bf93a4

4 files changed

Lines changed: 142 additions & 3 deletions

File tree

src/ast/mod.rs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5683,6 +5683,46 @@ impl fmt::Display for FunctionBehavior {
56835683
}
56845684
}
56855685

5686+
/// These attributes describe the behavior of the function when called with a null argument.
5687+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
5688+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
5689+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
5690+
pub enum FunctionCalledOnNull {
5691+
CalledOnNullInput,
5692+
ReturnsNullOnNullInput,
5693+
Strict,
5694+
}
5695+
5696+
impl fmt::Display for FunctionCalledOnNull {
5697+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
5698+
match self {
5699+
FunctionCalledOnNull::CalledOnNullInput => write!(f, "CALLED ON NULL INPUT"),
5700+
FunctionCalledOnNull::ReturnsNullOnNullInput => write!(f, "RETURNS NULL ON NULL INPUT"),
5701+
FunctionCalledOnNull::Strict => write!(f, "STRICT"),
5702+
}
5703+
}
5704+
}
5705+
5706+
/// If it is safe for PostgreSQL to call the function from multiple threads at once
5707+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
5708+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
5709+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
5710+
pub enum FunctionParallel {
5711+
Unsafe,
5712+
Restricted,
5713+
Safe,
5714+
}
5715+
5716+
impl fmt::Display for FunctionParallel {
5717+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
5718+
match self {
5719+
FunctionParallel::Unsafe => write!(f, "PARALLEL UNSAFE"),
5720+
FunctionParallel::Restricted => write!(f, "PARALLEL RESTRICTED"),
5721+
FunctionParallel::Safe => write!(f, "PARALLEL SAFE"),
5722+
}
5723+
}
5724+
}
5725+
56865726
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
56875727
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
56885728
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@@ -5703,7 +5743,7 @@ impl fmt::Display for FunctionDefinition {
57035743

57045744
/// Postgres specific feature.
57055745
///
5706-
/// See [Postgresdocs](https://www.postgresql.org/docs/15/sql-createfunction.html)
5746+
/// See [Postgres docs](https://www.postgresql.org/docs/15/sql-createfunction.html)
57075747
/// for more details
57085748
#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
57095749
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -5713,6 +5753,10 @@ pub struct CreateFunctionBody {
57135753
pub language: Option<Ident>,
57145754
/// IMMUTABLE | STABLE | VOLATILE
57155755
pub behavior: Option<FunctionBehavior>,
5756+
/// CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT
5757+
pub called_on_null: Option<FunctionCalledOnNull>,
5758+
/// PARALLEL { UNSAFE | RESTRICTED | SAFE }
5759+
pub parallel: Option<FunctionParallel>,
57165760
/// AS 'definition'
57175761
///
57185762
/// Note that Hive's `AS class_name` is also parsed here.
@@ -5731,6 +5775,12 @@ impl fmt::Display for CreateFunctionBody {
57315775
if let Some(behavior) = &self.behavior {
57325776
write!(f, " {behavior}")?;
57335777
}
5778+
if let Some(called_on_null) = &self.called_on_null {
5779+
write!(f, " {called_on_null}")?;
5780+
}
5781+
if let Some(parallel) = &self.parallel {
5782+
write!(f, " {parallel}")?;
5783+
}
57345784
if let Some(definition) = &self.as_ {
57355785
write!(f, " AS {definition}")?;
57365786
}

src/keywords.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ define_keywords!(
353353
INITIALLY,
354354
INNER,
355355
INOUT,
356+
INPUT,
356357
INPUTFORMAT,
357358
INSENSITIVE,
358359
INSERT,
@@ -498,6 +499,7 @@ define_keywords!(
498499
OVERLAY,
499500
OVERWRITE,
500501
OWNED,
502+
PARALLEL,
501503
PARAMETER,
502504
PARQUET,
503505
PARTITION,
@@ -570,6 +572,7 @@ define_keywords!(
570572
RESPECT,
571573
RESTART,
572574
RESTRICT,
575+
RESTRICTED,
573576
RESULT,
574577
RESULTSET,
575578
RETAIN,
@@ -589,6 +592,7 @@ define_keywords!(
589592
ROW_NUMBER,
590593
RULE,
591594
RUN,
595+
SAFE,
592596
SAFE_CAST,
593597
SAVEPOINT,
594598
SCHEMA,
@@ -704,6 +708,7 @@ define_keywords!(
704708
UNLOGGED,
705709
UNNEST,
706710
UNPIVOT,
711+
UNSAFE,
707712
UNSIGNED,
708713
UNTIL,
709714
UPDATE,

src/parser/mod.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3437,6 +3437,46 @@ impl<'a> Parser<'a> {
34373437
} else if self.parse_keyword(Keyword::VOLATILE) {
34383438
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
34393439
body.behavior = Some(FunctionBehavior::Volatile);
3440+
} else if self.parse_keywords(&[
3441+
Keyword::CALLED,
3442+
Keyword::ON,
3443+
Keyword::NULL,
3444+
Keyword::INPUT,
3445+
]) {
3446+
ensure_not_set(
3447+
&body.called_on_null,
3448+
"CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT",
3449+
)?;
3450+
body.called_on_null = Some(FunctionCalledOnNull::CalledOnNullInput);
3451+
} else if self.parse_keywords(&[
3452+
Keyword::RETURNS,
3453+
Keyword::NULL,
3454+
Keyword::ON,
3455+
Keyword::NULL,
3456+
Keyword::INPUT,
3457+
]) {
3458+
ensure_not_set(
3459+
&body.called_on_null,
3460+
"CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT",
3461+
)?;
3462+
body.called_on_null = Some(FunctionCalledOnNull::ReturnsNullOnNullInput);
3463+
} else if self.parse_keyword(Keyword::STRICT) {
3464+
ensure_not_set(
3465+
&body.called_on_null,
3466+
"CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT",
3467+
)?;
3468+
body.called_on_null = Some(FunctionCalledOnNull::Strict);
3469+
} else if self.parse_keyword(Keyword::PARALLEL) {
3470+
ensure_not_set(&body.parallel, "PARALLEL { UNSAFE | RESTRICTED | SAFE }")?;
3471+
if self.parse_keyword(Keyword::UNSAFE) {
3472+
body.parallel = Some(FunctionParallel::Unsafe);
3473+
} else if self.parse_keyword(Keyword::RESTRICTED) {
3474+
body.parallel = Some(FunctionParallel::Restricted);
3475+
} else if self.parse_keyword(Keyword::SAFE) {
3476+
body.parallel = Some(FunctionParallel::Safe);
3477+
} else {
3478+
return self.expected("one of UNSAFE | RESTRICTED | SAFE", self.peek_token());
3479+
}
34403480
} else if self.parse_keyword(Keyword::RETURN) {
34413481
ensure_not_set(&body.return_, "RETURN")?;
34423482
body.return_ = Some(self.parse_expr()?);

tests/sqlparser_postgres.rs

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3280,7 +3280,7 @@ fn parse_similar_to() {
32803280

32813281
#[test]
32823282
fn parse_create_function() {
3283-
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE AS 'select $1 + $2;'";
3283+
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE STRICT PARALLEL SAFE AS 'select $1 + $2;'";
32843284
assert_eq!(
32853285
pg_and_generic().verified_stmt(sql),
32863286
Statement::CreateFunction {
@@ -3295,6 +3295,8 @@ fn parse_create_function() {
32953295
params: CreateFunctionBody {
32963296
language: Some("SQL".into()),
32973297
behavior: Some(FunctionBehavior::Immutable),
3298+
called_on_null: Some(FunctionCalledOnNull::Strict),
3299+
parallel: Some(FunctionParallel::Safe),
32983300
as_: Some(FunctionDefinition::SingleQuotedDef(
32993301
"select $1 + $2;".into()
33003302
)),
@@ -3303,7 +3305,7 @@ fn parse_create_function() {
33033305
}
33043306
);
33053307

3306-
let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL IMMUTABLE RETURN a + b";
3308+
let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT PARALLEL RESTRICTED RETURN a + b";
33073309
assert_eq!(
33083310
pg_and_generic().verified_stmt(sql),
33093311
Statement::CreateFunction {
@@ -3323,6 +3325,40 @@ fn parse_create_function() {
33233325
params: CreateFunctionBody {
33243326
language: Some("SQL".into()),
33253327
behavior: Some(FunctionBehavior::Immutable),
3328+
called_on_null: Some(FunctionCalledOnNull::ReturnsNullOnNullInput),
3329+
parallel: Some(FunctionParallel::Restricted),
3330+
return_: Some(Expr::BinaryOp {
3331+
left: Box::new(Expr::Identifier("a".into())),
3332+
op: BinaryOperator::Plus,
3333+
right: Box::new(Expr::Identifier("b".into())),
3334+
}),
3335+
..Default::default()
3336+
},
3337+
}
3338+
);
3339+
3340+
let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL STABLE CALLED ON NULL INPUT PARALLEL UNSAFE RETURN a + b";
3341+
assert_eq!(
3342+
pg_and_generic().verified_stmt(sql),
3343+
Statement::CreateFunction {
3344+
or_replace: true,
3345+
temporary: false,
3346+
name: ObjectName(vec![Ident::new("add")]),
3347+
args: Some(vec![
3348+
OperateFunctionArg::with_name("a", DataType::Integer(None)),
3349+
OperateFunctionArg {
3350+
mode: Some(ArgMode::In),
3351+
name: Some("b".into()),
3352+
data_type: DataType::Integer(None),
3353+
default_expr: Some(Expr::Value(Value::Number("1".parse().unwrap(), false))),
3354+
}
3355+
]),
3356+
return_type: Some(DataType::Integer(None)),
3357+
params: CreateFunctionBody {
3358+
language: Some("SQL".into()),
3359+
behavior: Some(FunctionBehavior::Stable),
3360+
called_on_null: Some(FunctionCalledOnNull::CalledOnNullInput),
3361+
parallel: Some(FunctionParallel::Unsafe),
33263362
return_: Some(Expr::BinaryOp {
33273363
left: Box::new(Expr::Identifier("a".into())),
33283364
op: BinaryOperator::Plus,
@@ -3348,6 +3384,8 @@ fn parse_create_function() {
33483384
params: CreateFunctionBody {
33493385
language: Some("plpgsql".into()),
33503386
behavior: None,
3387+
called_on_null: None,
3388+
parallel: None,
33513389
return_: None,
33523390
as_: Some(FunctionDefinition::DoubleDollarDef(
33533391
" BEGIN RETURN i + 1; END; ".into()
@@ -3358,6 +3396,12 @@ fn parse_create_function() {
33583396
);
33593397
}
33603398

3399+
#[test]
3400+
fn parse_incorrect_create_function_parallel() {
3401+
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL PARALLEL BLAH AS 'select $1 + $2;'";
3402+
assert!(pg().parse_sql_statements(sql).is_err());
3403+
}
3404+
33613405
#[test]
33623406
fn parse_drop_function() {
33633407
let sql = "DROP FUNCTION IF EXISTS test_func";

0 commit comments

Comments
 (0)