Skip to content

Commit 8ff84e8

Browse files
authored
feat: improve subquery projection rule (#294)
* feat: improve subquery projection rule * chore: revert unnecessary change
1 parent 6b8b504 commit 8ff84e8

1 file changed

Lines changed: 150 additions & 5 deletions

File tree

datafusion-pg-catalog/src/sql/rules.rs

Lines changed: 150 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use datafusion::sql::sqlparser::ast::FunctionArgExpr;
1414
use datafusion::sql::sqlparser::ast::FunctionArgumentList;
1515
use datafusion::sql::sqlparser::ast::FunctionArguments;
1616
use datafusion::sql::sqlparser::ast::Ident;
17+
use datafusion::sql::sqlparser::ast::LimitClause;
1718
use datafusion::sql::sqlparser::ast::ObjectName;
1819
use datafusion::sql::sqlparser::ast::ObjectNamePart;
1920
use datafusion::sql::sqlparser::ast::OrderByKind;
@@ -30,6 +31,7 @@ use datafusion::sql::sqlparser::ast::UnaryOperator;
3031
use datafusion::sql::sqlparser::ast::Value;
3132
use datafusion::sql::sqlparser::ast::ValueWithSpan;
3233
use datafusion::sql::sqlparser::ast::VisitMut;
34+
use datafusion::sql::sqlparser::ast::Visitor;
3335
use datafusion::sql::sqlparser::ast::VisitorMut;
3436
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
3537
use datafusion::sql::sqlparser::parser::Parser;
@@ -937,12 +939,84 @@ impl SqlStatementRewriteRule for FixCollate {
937939
}
938940
}
939941

940-
/// Datafusion doesn't support subquery on projection
942+
/// A processor that replaces unsupported subqueries in the projection with NULL.
943+
///
944+
/// It also adds `LIMIT 1` to each supported subquery to ensure it returns a scalar
945+
/// value.
941946
#[derive(Debug)]
942947
pub struct RemoveSubqueryFromProjection;
943948

944949
/// Stateless visitor whose `VisitorMut` implementation rewrites subqueries
/// found in a `SELECT` projection.
struct RemoveSubqueryFromProjectionVisitor;
945950

951+
impl RemoveSubqueryFromProjectionVisitor {
952+
fn has_correlation(&self, query: &Query) -> bool {
953+
if let SetExpr::Select(select) = &*query.body {
954+
let table_aliases: HashSet<String> = select
955+
.from
956+
.iter()
957+
.flat_map(|twj| {
958+
let mut aliases = HashSet::new();
959+
Self::collect_table_aliases_from_table_factor(&twj.relation, &mut aliases);
960+
for join in &twj.joins {
961+
Self::collect_table_aliases_from_table_factor(&join.relation, &mut aliases);
962+
}
963+
aliases
964+
})
965+
.collect();
966+
967+
let mut has_correlation = false;
968+
let mut visitor = CorrelationCheckVisitor(&mut has_correlation, &table_aliases);
969+
let _ = datafusion::logical_expr::sqlparser::ast::Visit::visit(query, &mut visitor);
970+
has_correlation
971+
} else {
972+
false
973+
}
974+
}
975+
976+
fn has_limit(&self, query: &Query) -> bool {
977+
query.limit_clause.is_some() || query.fetch.is_some()
978+
}
979+
980+
fn collect_table_aliases_from_table_factor(
981+
table_factor: &TableFactor,
982+
aliases: &mut HashSet<String>,
983+
) {
984+
if let TableFactor::Table {
985+
alias: Some(alias), ..
986+
} = table_factor
987+
{
988+
aliases.insert(alias.name.value.clone());
989+
}
990+
}
991+
}
992+
993+
struct CorrelationCheckVisitor<'a>(&'a mut bool, &'a HashSet<String>);
994+
995+
impl Visitor for CorrelationCheckVisitor<'_> {
996+
type Break = ();
997+
998+
fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
999+
match expr {
1000+
Expr::Value(ValueWithSpan {
1001+
value: Value::Placeholder(_placeholder),
1002+
..
1003+
}) => {
1004+
*self.0 = true;
1005+
}
1006+
Expr::CompoundIdentifier(idents) => {
1007+
if !idents.is_empty() {
1008+
let table_name = &idents[0].value;
1009+
if !self.1.contains(table_name) {
1010+
*self.0 = true;
1011+
}
1012+
}
1013+
}
1014+
_ => {}
1015+
}
1016+
ControlFlow::Continue(())
1017+
}
1018+
}
1019+
9461020
impl VisitorMut for RemoveSubqueryFromProjectionVisitor {
9471021
type Break = ();
9481022

@@ -951,13 +1025,33 @@ impl VisitorMut for RemoveSubqueryFromProjectionVisitor {
9511025
for projection in &mut select.projection {
9521026
match projection {
9531027
SelectItem::UnnamedExpr(expr) => {
954-
if let Expr::Subquery(_) = expr {
955-
*expr = Expr::Value(Value::Null.with_empty_span());
1028+
if let Expr::Subquery(subquery) = expr {
1029+
if self.has_correlation(subquery) {
1030+
*expr = Expr::Value(Value::Null.with_empty_span());
1031+
} else if !self.has_limit(subquery) {
1032+
subquery.limit_clause = Some(LimitClause::LimitOffset {
1033+
limit: Some(Expr::Value(
1034+
Value::Number("1".to_string(), false).with_empty_span(),
1035+
)),
1036+
offset: None,
1037+
limit_by: vec![],
1038+
});
1039+
}
9561040
}
9571041
}
9581042
SelectItem::ExprWithAlias { expr, .. } => {
959-
if let Expr::Subquery(_) = expr {
960-
*expr = Expr::Value(Value::Null.with_empty_span());
1043+
if let Expr::Subquery(subquery) = expr {
1044+
if self.has_correlation(subquery) {
1045+
*expr = Expr::Value(Value::Null.with_empty_span());
1046+
} else if !self.has_limit(subquery) {
1047+
subquery.limit_clause = Some(LimitClause::LimitOffset {
1048+
limit: Some(Expr::Value(
1049+
Value::Number("1".to_string(), false).with_empty_span(),
1050+
)),
1051+
offset: None,
1052+
limit_by: vec![],
1053+
});
1054+
}
9611055
}
9621056
}
9631057
_ => {}
@@ -1309,6 +1403,57 @@ mod tests {
13091403
"SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), NULL, a.attnotnull, NULL AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute AS a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum");
13101404
}
13111405

1406+
/// An uncorrelated aggregate subquery without a limit stays in the projection
/// and has `LIMIT 1` appended.
#[test]
fn test_keep_simple_aggregated_subquery() {
    let rule: Arc<dyn SqlStatementRewriteRule> = Arc::new(RemoveSubqueryFromProjection);
    let rules = vec![rule];

    assert_rewrite!(
        &rules,
        "SELECT id, (SELECT COUNT(*) FROM pg_catalog.pg_attribute) AS attr_count FROM pg_catalog.pg_class",
        "SELECT id, (SELECT COUNT(*) FROM pg_catalog.pg_attribute LIMIT 1) AS attr_count FROM pg_catalog.pg_class"
    );
}
1416+
1417+
/// A subquery that refers to the outer alias `a` is replaced with `NULL`,
/// keeping the projection alias.
#[test]
fn test_remove_correlated_subquery() {
    let rule: Arc<dyn SqlStatementRewriteRule> = Arc::new(RemoveSubqueryFromProjection);
    let rules = vec![rule];

    assert_rewrite!(
        &rules,
        "SELECT a.attname, (SELECT COUNT(*) FROM pg_catalog.pg_attribute WHERE attrelid = a.oid) AS count FROM pg_catalog.pg_attribute a",
        "SELECT a.attname, NULL AS count FROM pg_catalog.pg_attribute AS a"
    );
}
1427+
1428+
/// A non-aggregated subquery that already has `LIMIT 1` passes through
/// unchanged.
///
/// NOTE(review): the name says "remove", but the expected SQL keeps the
/// subquery as written; consider renaming.
#[test]
fn test_remove_non_aggregated_subquery() {
    let rule: Arc<dyn SqlStatementRewriteRule> = Arc::new(RemoveSubqueryFromProjection);
    let rules = vec![rule];

    assert_rewrite!(
        &rules,
        "SELECT id, (SELECT attname FROM pg_catalog.pg_attribute LIMIT 1) AS first_attr FROM pg_catalog.pg_class",
        "SELECT id, (SELECT attname FROM pg_catalog.pg_attribute LIMIT 1) AS first_attr FROM pg_catalog.pg_class"
    );
}
1438+
1439+
/// Constant-valued subqueries (numeric and string literals) are kept and get
/// `LIMIT 1` appended.
#[test]
fn test_keep_simple_scalar_subquery() {
    let rule: Arc<dyn SqlStatementRewriteRule> = Arc::new(RemoveSubqueryFromProjection);
    let rules = vec![rule];

    // Numeric literal.
    assert_rewrite!(
        &rules,
        "SELECT (SELECT 1) AS constant",
        "SELECT (SELECT 1 LIMIT 1) AS constant"
    );

    // String literal.
    assert_rewrite!(
        &rules,
        "SELECT (SELECT 'value') AS str_val",
        "SELECT (SELECT 'value' LIMIT 1) AS str_val"
    );
}
1456+
13121457
#[test]
13131458
fn test_version_rewrite() {
13141459
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixVersionColumnName)];

0 commit comments

Comments
 (0)