Skip to content

Commit 8c99aea

Browse files
authored
feat: fallback to casted type for parameter type inference (#291)
1 parent 305234a commit 8c99aea

3 files changed

Lines changed: 72 additions & 7 deletions

File tree

datafusion-postgres/src/handlers.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
1919
use pgwire::error::{PgWireError, PgWireResult};
2020
use pgwire::types::format::FormatOptions;
2121

22-
use crate::client;
2322
use crate::hooks::set_show::SetShowHook;
2423
use crate::hooks::transactions::TransactionStatementHook;
2524
use crate::hooks::QueryHook;
25+
use crate::{client, planner};
2626
use arrow_pg::datatypes::df;
2727
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
2828
use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
@@ -215,8 +215,7 @@ impl ExtendedQueryHandler for DfSessionService {
215215
if !self.query_hooks.is_empty() {
216216
if let (_, Some((statement, plan))) = &portal.statement.statement {
217217
// TODO: in the case where query hooks all return None, we do the param handling again later.
218-
let param_types = plan
219-
.get_parameter_types()
218+
let param_types = planner::get_inferred_parameter_types(plan)
220219
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
221220

222221
let param_values: ParamValues =
@@ -240,8 +239,7 @@ impl ExtendedQueryHandler for DfSessionService {
240239
}
241240

242241
if let (_, Some((statement, plan))) = &portal.statement.statement {
243-
let param_types = plan
244-
.get_parameter_types()
242+
let param_types = planner::get_inferred_parameter_types(plan)
245243
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
246244

247245
let param_values =
@@ -381,8 +379,7 @@ impl QueryParser for Parser {
381379

382380
fn get_parameter_types(&self, stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
383381
if let (_, Some((_, plan))) = stmt {
384-
let params = plan
385-
.get_parameter_types()
382+
let params = planner::get_inferred_parameter_types(plan)
386383
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
387384

388385
let mut param_types = Vec::with_capacity(params.len());

datafusion-postgres/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ pub mod auth;
22
pub(crate) mod client;
33
mod handlers;
44
pub mod hooks;
5+
mod planner;
56
#[cfg(any(test, debug_assertions))]
67
pub mod testing;
78

datafusion-postgres/src/planner.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
use std::collections::{HashMap, HashSet};
2+
3+
use datafusion::arrow::datatypes::DataType;
4+
use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
5+
use datafusion::error::Result;
6+
use datafusion::logical_expr::LogicalPlan;
7+
use datafusion::prelude::Expr;
8+
9+
fn extract_placeholder_cast_types(plan: &LogicalPlan) -> Result<HashMap<String, Option<DataType>>> {
10+
let mut placeholder_types = HashMap::new();
11+
let mut casted_placeholders = HashSet::new();
12+
13+
plan.apply(|node| {
14+
for expr in node.expressions() {
15+
let _ = expr.apply(|e| {
16+
if let Expr::Cast(cast) = e {
17+
if let Expr::Placeholder(ph) = &*cast.expr {
18+
placeholder_types.insert(ph.id.clone(), Some(cast.data_type.clone()));
19+
casted_placeholders.insert(ph.id.clone());
20+
}
21+
}
22+
23+
if let Expr::Placeholder(ph) = e {
24+
if !casted_placeholders.contains(&ph.id)
25+
&& !placeholder_types.contains_key(&ph.id)
26+
{
27+
placeholder_types.insert(ph.id.clone(), None);
28+
}
29+
}
30+
31+
Ok(TreeNodeRecursion::Continue)
32+
});
33+
}
34+
Ok(TreeNodeRecursion::Continue)
35+
})?;
36+
37+
Ok(placeholder_types)
38+
}
39+
40+
pub fn get_inferred_parameter_types(
41+
plan: &LogicalPlan,
42+
) -> Result<HashMap<String, Option<DataType>>> {
43+
let param_types = plan.get_parameter_types()?;
44+
45+
let has_none = param_types.values().any(|v| v.is_none());
46+
47+
if !has_none {
48+
Ok(param_types)
49+
} else {
50+
let cast_types = extract_placeholder_cast_types(plan)?;
51+
52+
let mut merged = param_types;
53+
54+
for (id, opt_type) in cast_types {
55+
merged
56+
.entry(id)
57+
.and_modify(|existing| {
58+
if existing.is_none() {
59+
*existing = opt_type.clone();
60+
}
61+
})
62+
.or_insert(opt_type);
63+
}
64+
65+
Ok(merged)
66+
}
67+
}

0 commit comments

Comments
 (0)