Skip to content

Commit 6b8b504

Browse files
authored
feat: add rule to support ::regclass::oid cast (#288)
* feat: add rule to support ::regclass::oid cast * test: add pgadbc sql tests * fix: lint * feat: update sql and support schema
1 parent 8c99aea commit 6b8b504

3 files changed

Lines changed: 217 additions & 0 deletions

File tree

datafusion-pg-catalog/src/sql/parser.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use super::rules::RemoveSubqueryFromProjection;
1818
use super::rules::RemoveUnsupportedTypes;
1919
use super::rules::ResolveUnqualifiedIdentifer;
2020
use super::rules::RewriteArrayAnyAllOperation;
21+
use super::rules::RewriteRegclassCastToSubquery;
2122
use super::rules::SqlStatementRewriteRule;
2223

2324
const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[
@@ -228,6 +229,7 @@ impl PostgresCompatibilityParser {
228229
Arc::new(RewriteArrayAnyAllOperation),
229230
Arc::new(PrependUnqualifiedPgTableName),
230231
Arc::new(RemoveQualifier),
232+
Arc::new(RewriteRegclassCastToSubquery::new()),
231233
Arc::new(RemoveUnsupportedTypes::new()),
232234
Arc::new(FixArrayLiteral),
233235
Arc::new(CurrentUserVariableToSessionUserFunctionCall),

datafusion-pg-catalog/src/sql/rules.rs

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ use datafusion::sql::sqlparser::ast::Value;
3131
use datafusion::sql::sqlparser::ast::ValueWithSpan;
3232
use datafusion::sql::sqlparser::ast::VisitMut;
3333
use datafusion::sql::sqlparser::ast::VisitorMut;
34+
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
35+
use datafusion::sql::sqlparser::parser::Parser;
3436

3537
pub trait SqlStatementRewriteRule: Send + Sync + Debug {
3638
fn rewrite(&self, s: Statement) -> Statement;
@@ -382,6 +384,165 @@ impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
382384
}
383385
}
384386

387+
/// Rewrite regclass::oid cast to subquery
388+
///
389+
/// This rewrites patterns like `$1::regclass::oid` to
390+
/// `(SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)`
391+
#[derive(Debug)]
392+
pub struct RewriteRegclassCastToSubquery(Box<Query>);
393+
394+
impl Default for RewriteRegclassCastToSubquery {
395+
fn default() -> Self {
396+
Self::new()
397+
}
398+
}
399+
400+
impl RewriteRegclassCastToSubquery {
401+
pub fn new() -> Self {
402+
let sql = "SELECT c.oid
403+
FROM pg_catalog.pg_class c
404+
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
405+
CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) p
406+
WHERE n.nspname = COALESCE(
407+
CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END,
408+
current_schema()
409+
)
410+
AND c.relname = p.parts[-1]";
411+
let dialect = PostgreSqlDialect {};
412+
let query = Parser::parse_sql(&dialect, sql)
413+
.map(|mut stmts| {
414+
let stmt = stmts.remove(0);
415+
if let Statement::Query(query) = stmt {
416+
query
417+
} else {
418+
unreachable!()
419+
}
420+
})
421+
.expect("Failed to parse prepared query");
422+
Self(query)
423+
}
424+
}
425+
426+
struct RewriteRegclassCastToSubqueryVisitor(Box<Query>);
427+
428+
impl RewriteRegclassCastToSubqueryVisitor {
429+
pub fn new(query: Box<Query>) -> Self {
430+
Self(query)
431+
}
432+
433+
fn create_subquery(&self, expr: &Expr) -> Expr {
434+
struct PlaceholderReplacer(Expr);
435+
436+
impl VisitorMut for PlaceholderReplacer {
437+
type Break = ();
438+
439+
fn pre_visit_expr(&mut self, e: &mut Expr) -> ControlFlow<Self::Break> {
440+
if let Expr::Value(ValueWithSpan {
441+
value: Value::Placeholder(_placeholder),
442+
..
443+
}) = e
444+
{
445+
*e = self.0.clone();
446+
}
447+
ControlFlow::Continue(())
448+
}
449+
}
450+
451+
let mut query = self.0.clone();
452+
let mut replacer = PlaceholderReplacer(expr.clone());
453+
let _ = query.visit(&mut replacer);
454+
Expr::Subquery(query)
455+
}
456+
457+
fn is_regclass_to_oid_cast(&self, expr: &Expr) -> bool {
458+
if let Expr::Cast {
459+
kind,
460+
data_type,
461+
expr: inner_expr,
462+
format: _,
463+
} = expr
464+
{
465+
if *kind == CastKind::DoubleColon {
466+
let dt_lower = data_type.to_string().to_lowercase();
467+
if dt_lower == "oid" || dt_lower == "pg_catalog.oid" {
468+
return self.is_regclass_cast(inner_expr);
469+
}
470+
}
471+
}
472+
false
473+
}
474+
475+
fn is_regclass_cast(&self, expr: &Expr) -> bool {
476+
if let Expr::Cast {
477+
kind,
478+
data_type,
479+
expr: _,
480+
format: _,
481+
} = expr
482+
{
483+
if *kind == CastKind::DoubleColon {
484+
let dt_lower = data_type.to_string().to_lowercase();
485+
return dt_lower == "regclass" || dt_lower == "pg_catalog.regclass";
486+
}
487+
}
488+
false
489+
}
490+
491+
fn extract_inner_expr(&self, expr: &Expr) -> Option<Expr> {
492+
if let Expr::Cast {
493+
kind,
494+
data_type,
495+
expr: inner_expr,
496+
format: _,
497+
} = expr
498+
{
499+
if *kind == CastKind::DoubleColon {
500+
let dt_lower = data_type.to_string().to_lowercase();
501+
if dt_lower == "oid" || dt_lower == "pg_catalog.oid" {
502+
if let Expr::Cast {
503+
kind: inner_kind,
504+
data_type: inner_data_type,
505+
expr: inner_inner_expr,
506+
format: _,
507+
} = inner_expr.as_ref()
508+
{
509+
if *inner_kind == CastKind::DoubleColon {
510+
let inner_dt_lower = inner_data_type.to_string().to_lowercase();
511+
if inner_dt_lower == "regclass"
512+
|| inner_dt_lower == "pg_catalog.regclass"
513+
{
514+
return Some((**inner_inner_expr).clone());
515+
}
516+
}
517+
}
518+
}
519+
}
520+
}
521+
None
522+
}
523+
}
524+
525+
impl VisitorMut for RewriteRegclassCastToSubqueryVisitor {
526+
type Break = ();
527+
528+
fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
529+
if self.is_regclass_to_oid_cast(expr) {
530+
if let Some(inner_expr) = self.extract_inner_expr(expr) {
531+
*expr = self.create_subquery(&inner_expr);
532+
}
533+
}
534+
ControlFlow::Continue(())
535+
}
536+
}
537+
538+
impl SqlStatementRewriteRule for RewriteRegclassCastToSubquery {
539+
fn rewrite(&self, mut s: Statement) -> Statement {
540+
let mut visitor = RewriteRegclassCastToSubqueryVisitor::new(self.0.clone());
541+
let _ = s.visit(&mut visitor);
542+
s
543+
}
544+
}
545+
385546
/// Rewrite Postgres's ANY operator to array_contains
386547
#[derive(Debug)]
387548
pub struct RewriteArrayAnyAllOperation;
@@ -997,6 +1158,36 @@ mod tests {
9971158
);
9981159
}
9991160

1161+
#[test]
1162+
fn test_rewrite_regclass_cast_to_subquery() {
1163+
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
1164+
vec![Arc::new(RewriteRegclassCastToSubquery::new())];
1165+
1166+
assert_rewrite!(
1167+
&rules,
1168+
"SELECT $1::regclass::oid",
1169+
"SELECT (SELECT c.oid FROM pg_catalog.pg_class AS c JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) AS p WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) AND c.relname = p.parts[-1])"
1170+
);
1171+
1172+
assert_rewrite!(
1173+
&rules,
1174+
"SELECT $1::pg_catalog.regclass::oid",
1175+
"SELECT (SELECT c.oid FROM pg_catalog.pg_class AS c JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) AS p WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) AND c.relname = p.parts[-1])"
1176+
);
1177+
1178+
assert_rewrite!(
1179+
&rules,
1180+
"SELECT $1::pg_catalog.regclass::pg_catalog.oid",
1181+
"SELECT (SELECT c.oid FROM pg_catalog.pg_class AS c JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) AS p WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) AND c.relname = p.parts[-1])"
1182+
);
1183+
1184+
assert_rewrite!(
1185+
&rules,
1186+
"SELECT * FROM pg_catalog.pg_class WHERE oid = 't'::pg_catalog.regclass::pg_catalog.oid",
1187+
"SELECT * FROM pg_catalog.pg_class WHERE oid = (SELECT c.oid FROM pg_catalog.pg_class AS c JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace CROSS JOIN (SELECT parse_ident('t'::TEXT) AS parts) AS p WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) AND c.relname = p.parts[-1])"
1188+
);
1189+
}
1190+
10001191
#[test]
10011192
fn test_any_to_array_contains() {
10021193
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
use pgwire::api::query::SimpleQueryHandler;
2+
3+
use datafusion_postgres::testing::*;
4+
5+
const PGADBC_QUERIES: &[&str] = &[
6+
"SELECT attname, atttypid FROM pg_catalog.pg_class AS cls INNER JOIN pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid INNER JOIN pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid WHERE attr.attnum >= 0 AND cls.oid = 'clubs'::regclass::oid ORDER BY attr.attnum",
7+
8+
9+
];
10+
11+
#[tokio::test]
12+
pub async fn test_pgadbc_metadata_sql() {
13+
env_logger::init();
14+
let service = setup_handlers();
15+
let mut client = MockClient::new();
16+
17+
for query in PGADBC_QUERIES {
18+
SimpleQueryHandler::do_query(&service, &mut client, query)
19+
.await
20+
.unwrap_or_else(|e| {
21+
panic!("failed to run sql:\n--------------\n {query}\n--------------\n{e}")
22+
});
23+
}
24+
}

0 commit comments

Comments
 (0)