Skip to content

Commit 718a833

Browse files
ParameterStatus on SET statements (#293)
* test: capture backend messages in MockClient * test: add failing test for ParameterStatus on SET statements * feat: add parameter_status_key_for_set helper in set_show hook * feat: send ParameterStatus on successful SET statements * test: verify ParameterStatus key-value for all SET variable types * style: apply cargo fmt * fix: use canonical PostgreSQL casing for ParameterStatus names DateStyle and IntervalStyle must match the casing sent during startup. * fix: mistake in timezone. * test: add integration tests for ParameterStatus sent via Sink Merge unit tests for parameter_status_key_for_set into single table-driven test. Add handler-level tests that call do_query() and assert ParameterStatus messages appear in MockClient's sent_messages, closing the gap where the actual feed() path was untested. Remove standalone timezone SingleAssignment case from parameter_status_key_for_set (SET TIME ZONE uses SetTimeZone). * chore: cargo fmt * refactor: move ParameterStatus logic into QueryHook return type Address PR review feedback: ParameterStatus computation now lives in try_respond_set_statements instead of handlers.rs. QueryHook trait returns HookOutput (Response + optional ParameterStatus data), handler just forwards it to the client Sink. * refactor: replace HookOutput tuple with #[non_exhaustive] builder struct Replace bare type aliases (ParameterStatusChange, HookOutput) with a proper struct supporting named fields and builder pattern. This makes future additions (notices, multiple ParameterStatus) non-breaking. * refactor: introduce HookClient supertrait to give hooks direct Sink access Replace the HookOutput return-type approach (where hooks returned ParameterStatus data for the handler to send) with an adapter pattern that lets hooks send messages directly through the client connection. - Add HookClient trait (supertrait of ClientInfo) with send_parameter_status() - Add PgHookClient<C> adapter that bridges generic pgwire client to dyn HookClient - Simplify handler loop: no more post-hook ParameterStatus forwarding - Remove HookOutput struct; hooks return Response directly - Update all hooks (SetShow, Transactions, Permissions) and tests * refactor: replace PgHookClient adapter with blanket impl for HookClient Use a blanket `impl<S> HookClient for S where S: ClientInfo + Sink<...>` so any pgwire client automatically implements HookClient. This removes the PgHookClient wrapper struct and ~110 lines of manual ClientInfo delegation and Sink forwarding. Also generalizes send_parameter_status to send_message, allowing hooks to send any PgWireBackendMessage.
1 parent 86325a6 commit 718a833

6 files changed

Lines changed: 284 additions & 49 deletions

File tree

datafusion-postgres/src/handlers.rs

Lines changed: 96 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use pgwire::api::results::{FieldInfo, Response, Tag};
1717
use pgwire::api::stmt::QueryParser;
1818
use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
1919
use pgwire::error::{PgWireError, PgWireResult};
20+
use pgwire::messages::PgWireBackendMessage;
2021
use pgwire::types::format::FormatOptions;
2122

2223
use crate::hooks::set_show::SetShowHook;
@@ -119,10 +120,11 @@ impl DfSessionService {
119120
impl SimpleQueryHandler for DfSessionService {
120121
async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
121122
where
122-
C: ClientInfo + Unpin + Send + Sync,
123+
C: ClientInfo + futures::Sink<PgWireBackendMessage> + Unpin + Send + Sync,
124+
C::Error: std::fmt::Debug,
125+
PgWireError: From<<C as futures::Sink<PgWireBackendMessage>>::Error>,
123126
{
124-
log::debug!("Received query: {query}"); // Log the query for debugging
125-
127+
log::debug!("Received query: {query}");
126128
let statements = self
127129
.parser
128130
.sql_parser
@@ -206,11 +208,12 @@ impl ExtendedQueryHandler for DfSessionService {
206208
_max_rows: usize,
207209
) -> PgWireResult<Response>
208210
where
209-
C: ClientInfo + Unpin + Send + Sync,
211+
C: ClientInfo + futures::Sink<PgWireBackendMessage> + Unpin + Send + Sync,
212+
C::Error: std::fmt::Debug,
213+
PgWireError: From<<C as futures::Sink<PgWireBackendMessage>>::Error>,
210214
{
211215
let query = &portal.statement.statement.0;
212-
log::debug!("Received execute extended query: {query}"); // Log for debugging
213-
216+
log::debug!("Received execute extended query: {query}");
214217
// Check query hooks first
215218
if !self.query_hooks.is_empty() {
216219
if let (_, Some((statement, plan))) = &portal.statement.statement {
@@ -243,13 +246,12 @@ impl ExtendedQueryHandler for DfSessionService {
243246
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
244247

245248
let param_values =
246-
df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
249+
df::deserialize_parameters(portal, &ordered_param_types(&param_types))?;
247250

248251
let plan = plan
249252
.clone()
250253
.replace_params_with_values(&param_values)
251-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use
252-
// &param_values
254+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
253255
let optimised = self
254256
.session_context
255257
.state()
@@ -345,8 +347,7 @@ impl QueryParser for Parser {
345347
where
346348
C: ClientInfo + Unpin + Send + Sync,
347349
{
348-
log::debug!("Received parse extended query: {sql}"); // Log for debugging
349-
350+
log::debug!("Received parse extended query: {sql}");
350351
let mut statements = self
351352
.sql_parser
352353
.parse(sql)
@@ -384,7 +385,6 @@ impl QueryParser for Parser {
384385

385386
let mut param_types = Vec::with_capacity(params.len());
386387
for param_type in ordered_param_types(&params).iter() {
387-
// Fixed: Use &params
388388
if let Some(datatype) = param_type {
389389
let pgtype = into_pg_type(datatype)?;
390390
param_types.push(pgtype);
@@ -434,6 +434,8 @@ mod tests {
434434
use super::*;
435435
use crate::testing::MockClient;
436436

437+
use crate::hooks::HookClient;
438+
437439
struct TestHook;
438440

439441
#[async_trait]
@@ -442,7 +444,7 @@ mod tests {
442444
&self,
443445
statement: &sqlparser::ast::Statement,
444446
_ctx: &SessionContext,
445-
_client: &mut (dyn ClientInfo + Sync + Send),
447+
_client: &mut dyn HookClient,
446448
) -> Option<PgWireResult<Response>> {
447449
if statement.to_string().contains("magic") {
448450
Some(Ok(Response::EmptyQuery))
@@ -466,7 +468,7 @@ mod tests {
466468
_logical_plan: &LogicalPlan,
467469
_params: &ParamValues,
468470
_session_context: &SessionContext,
469-
_client: &mut (dyn ClientInfo + Send + Sync),
471+
_client: &mut dyn HookClient,
470472
) -> Option<PgWireResult<Response>> {
471473
None
472474
}
@@ -523,4 +525,84 @@ mod tests {
523525
assert!(matches!(results[2], Response::EmptyQuery));
524526
assert!(matches!(results[3], Response::Query(_)));
525527
}
528+
529+
#[tokio::test]
530+
async fn test_set_sends_parameter_status_via_sink() {
531+
use pgwire::messages::PgWireBackendMessage;
532+
533+
let service = crate::testing::setup_handlers();
534+
let mut client = MockClient::new();
535+
536+
let test_cases = vec![
537+
("SET datestyle = 'ISO, MDY'", "DateStyle", "ISO, MDY"),
538+
(
539+
"SET intervalstyle = 'postgres'",
540+
"IntervalStyle",
541+
"postgres",
542+
),
543+
("SET bytea_output = 'hex'", "bytea_output", "hex"),
544+
(
545+
"SET application_name = 'myapp'",
546+
"application_name",
547+
"myapp",
548+
),
549+
("SET search_path = 'public'", "search_path", "public"),
550+
("SET extra_float_digits = '2'", "extra_float_digits", "2"),
551+
(
552+
"SET TIME ZONE 'America/New_York'",
553+
"TimeZone",
554+
"America/New_York",
555+
),
556+
];
557+
558+
for (sql, expected_key, expected_value) in test_cases {
559+
client.sent_messages.clear();
560+
561+
let responses =
562+
<DfSessionService as SimpleQueryHandler>::do_query(&service, &mut client, sql)
563+
.await
564+
.unwrap();
565+
566+
assert!(
567+
matches!(responses[0], Response::Execution(_)),
568+
"Expected SET tag for {sql}"
569+
);
570+
571+
let ps_msgs: Vec<_> = client
572+
.sent_messages()
573+
.iter()
574+
.filter_map(|m| match m {
575+
PgWireBackendMessage::ParameterStatus(ps) => Some(ps),
576+
_ => None,
577+
})
578+
.collect();
579+
580+
assert_eq!(ps_msgs.len(), 1, "Expected 1 ParameterStatus for {sql}");
581+
assert_eq!(ps_msgs[0].name, expected_key, "Wrong key for {sql}");
582+
assert_eq!(ps_msgs[0].value, expected_value, "Wrong value for {sql}");
583+
}
584+
}
585+
586+
#[tokio::test]
587+
async fn test_set_statement_timeout_no_parameter_status() {
588+
use pgwire::messages::PgWireBackendMessage;
589+
590+
let service = crate::testing::setup_handlers();
591+
let mut client = MockClient::new();
592+
593+
<DfSessionService as SimpleQueryHandler>::do_query(
594+
&service,
595+
&mut client,
596+
"SET statement_timeout TO '5000ms'",
597+
)
598+
.await
599+
.unwrap();
600+
601+
let has_ps = client
602+
.sent_messages()
603+
.iter()
604+
.any(|m| matches!(m, PgWireBackendMessage::ParameterStatus(_)));
605+
606+
assert!(!has_ps, "statement_timeout should not send ParameterStatus");
607+
}
526608
}

datafusion-postgres/src/hooks/mod.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,28 @@ use datafusion::common::ParamValues;
88
use datafusion::logical_expr::LogicalPlan;
99
use datafusion::prelude::SessionContext;
1010
use datafusion::sql::sqlparser::ast::Statement;
11+
use futures::Sink;
1112
use pgwire::api::results::Response;
1213
use pgwire::api::ClientInfo;
13-
use pgwire::error::PgWireResult;
14+
use pgwire::error::{PgWireError, PgWireResult};
15+
use pgwire::messages::PgWireBackendMessage;
16+
17+
#[async_trait]
18+
pub trait HookClient: ClientInfo + Send + Sync {
19+
async fn send_message(&mut self, item: PgWireBackendMessage) -> PgWireResult<()>;
20+
}
21+
22+
#[async_trait]
23+
impl<S> HookClient for S
24+
where
25+
S: ClientInfo + Sink<PgWireBackendMessage> + Send + Sync + Unpin,
26+
PgWireError: From<<S as Sink<PgWireBackendMessage>>::Error>,
27+
{
28+
async fn send_message(&mut self, item: PgWireBackendMessage) -> PgWireResult<()> {
29+
use futures::SinkExt;
30+
self.send(item).await.map_err(PgWireError::from)
31+
}
32+
}
1433

1534
#[async_trait]
1635
pub trait QueryHook: Send + Sync {
@@ -19,7 +38,7 @@ pub trait QueryHook: Send + Sync {
1938
&self,
2039
statement: &Statement,
2140
session_context: &SessionContext,
22-
client: &mut (dyn ClientInfo + Send + Sync),
41+
client: &mut dyn HookClient,
2342
) -> Option<PgWireResult<Response>>;
2443

2544
/// called at extended query parse phase, for generating `LogicalPlan`from statement
@@ -37,6 +56,6 @@ pub trait QueryHook: Send + Sync {
3756
logical_plan: &LogicalPlan,
3857
params: &ParamValues,
3958
session_context: &SessionContext,
40-
client: &mut (dyn ClientInfo + Send + Sync),
59+
client: &mut dyn HookClient,
4160
) -> Option<PgWireResult<Response>>;
4261
}

datafusion-postgres/src/hooks/permissions.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@ use datafusion::common::ParamValues;
55
use datafusion::logical_expr::LogicalPlan;
66
use datafusion::prelude::SessionContext;
77
use datafusion::sql::sqlparser::ast::Statement;
8-
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
98
use pgwire::api::results::Response;
109
use pgwire::api::ClientInfo;
1110
use pgwire::error::{PgWireError, PgWireResult};
1211

1312
use crate::auth::AuthManager;
13+
use crate::hooks::HookClient;
1414
use crate::QueryHook;
1515

16+
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
17+
1618
#[derive(Debug)]
1719
pub struct PermissionsHook {
1820
auth_manager: Arc<AuthManager>,
@@ -96,7 +98,7 @@ impl QueryHook for PermissionsHook {
9698
&self,
9799
statement: &Statement,
98100
_session_context: &SessionContext,
99-
client: &mut (dyn ClientInfo + Send + Sync),
101+
client: &mut dyn HookClient,
100102
) -> Option<PgWireResult<Response>> {
101103
if Self::should_skip_permission_check(statement) {
102104
return None;
@@ -125,7 +127,7 @@ impl QueryHook for PermissionsHook {
125127
_logical_plan: &LogicalPlan,
126128
_params: &ParamValues,
127129
_session_context: &SessionContext,
128-
client: &mut (dyn ClientInfo + Send + Sync),
130+
client: &mut dyn HookClient,
129131
) -> Option<PgWireResult<Response>> {
130132
if Self::should_skip_permission_check(statement) {
131133
return None;

0 commit comments

Comments
 (0)