Skip to content

Commit 827ddcd

Browse files
Streamlined derivation of new Dialect objects
1 parent 845e213 commit 827ddcd

17 files changed

Lines changed: 916 additions & 268 deletions

Cargo.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ std = []
4242
recursive-protection = ["std", "recursive"]
4343
# Enable JSON output in the `cli` example:
4444
json_example = ["serde_json", "serde"]
45+
# Enable derive macros for custom dialect creation
46+
derive-dialect = ["sqlparser_derive"]
4547
visitor = ["sqlparser_derive"]
4648

4749
[dependencies]
@@ -61,6 +63,10 @@ simple_logger = "5.0"
6163
matches = "0.1"
6264
pretty_assertions = "1"
6365

66+
[[test]]
67+
name = "sqlparser_derive_dialect"
68+
required-features = ["derive-dialect"]
69+
6470
[package.metadata.docs.rs]
6571
# Document these features on docs.rs
66-
features = ["serde", "visitor"]
72+
features = ["serde", "visitor", "derive-dialect"]

derive/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@ edition = "2021"
3636
proc-macro = true
3737

3838
[dependencies]
39-
syn = { version = "2.0", default-features = false, features = ["printing", "parsing", "derive", "proc-macro"] }
39+
syn = { version = "2.0", default-features = false, features = ["full", "printing", "parsing", "derive", "proc-macro", "clone-impls"] }
4040
proc-macro2 = "1.0"
4141
quote = "1.0"

derive/src/dialect.rs

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Implementation of the `derive_dialect!` macro for creating custom SQL dialects.
19+
20+
use proc_macro2::TokenStream;
21+
use quote::{quote, quote_spanned};
22+
use std::collections::HashSet;
23+
use syn::{
24+
braced,
25+
parse::{Parse, ParseStream},
26+
Error, File, FnArg, Ident, Item, LitBool, LitChar, Pat, ReturnType, Signature, Token,
27+
TraitItem, Type,
28+
};
29+
30+
/// Override value types supported by the macro
31+
pub(crate) enum Override {
32+
Bool(LitBool),
33+
Char(LitChar),
34+
None,
35+
}
36+
37+
/// Parsed input for the `derive_dialect!` macro
38+
pub(crate) struct DeriveDialectInput {
39+
pub name: Ident,
40+
pub base: Type,
41+
pub preserve_type_id: bool,
42+
pub overrides: Vec<(Ident, Override)>,
43+
}
44+
45+
/// Information about a Dialect trait method
46+
struct DialectMethod {
47+
name: Ident,
48+
signature: Signature,
49+
}
50+
51+
impl Parse for DeriveDialectInput {
52+
fn parse(input: ParseStream) -> syn::Result<Self> {
53+
let name: Ident = input.parse()?;
54+
input.parse::<Token![,]>()?;
55+
let base: Type = input.parse()?;
56+
57+
let mut preserve_type_id = false;
58+
let mut overrides = Vec::new();
59+
60+
while input.peek(Token![,]) {
61+
input.parse::<Token![,]>()?;
62+
if input.is_empty() {
63+
break;
64+
}
65+
if input.peek(Ident) {
66+
let ident: Ident = input.parse()?;
67+
match ident.to_string().as_str() {
68+
"preserve_type_id" => {
69+
input.parse::<Token![=]>()?;
70+
preserve_type_id = input.parse::<LitBool>()?.value();
71+
}
72+
"overrides" => {
73+
input.parse::<Token![=]>()?;
74+
let content;
75+
braced!(content in input);
76+
while !content.is_empty() {
77+
let key: Ident = content.parse()?;
78+
content.parse::<Token![=]>()?;
79+
let value = if content.peek(LitBool) {
80+
Override::Bool(content.parse()?)
81+
} else if content.peek(LitChar) {
82+
Override::Char(content.parse()?)
83+
} else if content.peek(Ident) {
84+
let ident: Ident = content.parse()?;
85+
if ident == "None" {
86+
Override::None
87+
} else {
88+
return Err(Error::new(
89+
ident.span(),
90+
format!("Expected `true`, `false`, a char, or `None`, found `{ident}`"),
91+
));
92+
}
93+
} else {
94+
return Err(
95+
content.error("Expected `true`, `false`, a char, or `None`")
96+
);
97+
};
98+
overrides.push((key, value));
99+
if content.peek(Token![,]) {
100+
content.parse::<Token![,]>()?;
101+
}
102+
}
103+
}
104+
other => {
105+
return Err(Error::new(ident.span(), format!(
106+
"Unknown argument `{other}`. Expected `preserve_type_id` or `overrides`."
107+
)));
108+
}
109+
}
110+
}
111+
}
112+
Ok(DeriveDialectInput {
113+
name,
114+
base,
115+
preserve_type_id,
116+
overrides,
117+
})
118+
}
119+
}
120+
121+
/// Entry point for the `derive_dialect!` procedural macro
122+
pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStream {
123+
let err = |msg: String| {
124+
Error::new(proc_macro2::Span::call_site(), msg)
125+
.to_compile_error()
126+
.into()
127+
};
128+
129+
let source = match read_dialect_mod_file() {
130+
Ok(s) => s,
131+
Err(e) => return err(format!("Failed to read dialect/mod.rs: {e}")),
132+
};
133+
let file: File = match syn::parse_str(&source) {
134+
Ok(f) => f,
135+
Err(e) => return err(format!("Failed to parse source: {e}")),
136+
};
137+
let methods = match extract_dialect_methods(&file) {
138+
Ok(m) => m,
139+
Err(e) => return e.to_compile_error().into(),
140+
};
141+
142+
// Validate overrides
143+
let bool_names: HashSet<_> = methods
144+
.iter()
145+
.filter(|m| is_bool_method(&m.signature))
146+
.map(|m| m.name.to_string())
147+
.collect();
148+
for (key, value) in &input.overrides {
149+
let key_str = key.to_string();
150+
let err = |msg| Error::new(key.span(), msg).to_compile_error().into();
151+
match value {
152+
Override::Bool(_) if !bool_names.contains(&key_str) => {
153+
return err(format!("Unknown boolean method `{key_str}`"));
154+
}
155+
Override::Char(_) | Override::None if key_str != "identifier_quote_style" => {
156+
return err(format!(
157+
"Char/None only valid for `identifier_quote_style`, not `{key_str}`"
158+
));
159+
}
160+
_ => {}
161+
}
162+
}
163+
generate_derived_dialect(&input, &methods).into()
164+
}
165+
166+
/// Generate the complete derived dialect implementation
167+
fn generate_derived_dialect(input: &DeriveDialectInput, methods: &[DialectMethod]) -> TokenStream {
168+
let name = &input.name;
169+
let base = &input.base;
170+
171+
// Helper to find an override by method name
172+
let find_override = |method_name: &str| {
173+
input
174+
.overrides
175+
.iter()
176+
.find(|(k, _)| k == method_name)
177+
.map(|(_, v)| v)
178+
};
179+
180+
// Helper to generate delegation to base dialect
181+
let delegate = |method: &DialectMethod| {
182+
let sig = &method.signature;
183+
let method_name = &method.name;
184+
let params = extract_param_names(sig);
185+
quote_spanned! { method_name.span() => #sig { self.dialect.#method_name(#(#params),*) } }
186+
};
187+
188+
// Generate the struct
189+
let struct_def = quote_spanned! { name.span() =>
190+
#[derive(Debug)]
191+
pub struct #name {
192+
dialect: #base,
193+
}
194+
impl Default for #name {
195+
fn default() -> Self {
196+
Self { dialect: <#base>::default() }
197+
}
198+
}
199+
impl #name {
200+
pub fn new() -> Self { Self::default() }
201+
}
202+
};
203+
204+
// Generate TypeId method body
205+
let type_id_body = if input.preserve_type_id {
206+
quote! { Dialect::dialect(&self.dialect) }
207+
} else {
208+
quote! { ::core::any::TypeId::of::<#name>() }
209+
};
210+
211+
// Generate method implementations
212+
let method_impls = methods.iter().map(|method| {
213+
let method_name = &method.name;
214+
match find_override(&method_name.to_string()) {
215+
Some(Override::Bool(value)) => {
216+
quote_spanned! { method_name.span() => fn #method_name(&self) -> bool { #value } }
217+
}
218+
Some(Override::Char(c)) => {
219+
quote_spanned! { method_name.span() =>
220+
fn identifier_quote_style(&self, _: &str) -> Option<char> { Some(#c) }
221+
}
222+
}
223+
Some(Override::None) => {
224+
quote_spanned! { method_name.span() =>
225+
fn identifier_quote_style(&self, _: &str) -> Option<char> { None }
226+
}
227+
}
228+
None => delegate(method),
229+
}
230+
});
231+
232+
// Wrap impl in a const block with scoped imports so types resolve without qualification
233+
quote! {
234+
#struct_def
235+
const _: () = {
236+
use ::core::iter::Peekable;
237+
use ::core::str::Chars;
238+
use sqlparser::ast::{ColumnOption, Expr, GranteesType, Ident, ObjectNamePart, Statement};
239+
use sqlparser::dialect::{Dialect, Precedence};
240+
use sqlparser::keywords::Keyword;
241+
use sqlparser::parser::{Parser, ParserError};
242+
243+
impl Dialect for #name {
244+
fn dialect(&self) -> ::core::any::TypeId { #type_id_body }
245+
#(#method_impls)*
246+
}
247+
};
248+
}
249+
}
250+
251+
/// Extract parameter names from a method signature (excluding self)
252+
fn extract_param_names(sig: &Signature) -> Vec<&Ident> {
253+
sig.inputs
254+
.iter()
255+
.filter_map(|arg| match arg {
256+
FnArg::Typed(pt) => match pt.pat.as_ref() {
257+
Pat::Ident(pi) => Some(&pi.ident),
258+
_ => None,
259+
},
260+
_ => None,
261+
})
262+
.collect()
263+
}
264+
265+
/// Read the dialect/mod.rs file that contains the Dialect trait.
266+
fn read_dialect_mod_file() -> Result<String, String> {
267+
let manifest_dir =
268+
std::env::var("CARGO_MANIFEST_DIR").map_err(|_| "CARGO_MANIFEST_DIR not set")?;
269+
let path = std::path::Path::new(&manifest_dir).join("src/dialect/mod.rs");
270+
std::fs::read_to_string(&path).map_err(|e| format!("Failed to read {}: {e}", path.display()))
271+
}
272+
273+
/// Extract all methods from the Dialect trait (excluding `dialect` for TypeId)
274+
fn extract_dialect_methods(file: &File) -> Result<Vec<DialectMethod>, Error> {
275+
let dialect_trait = file
276+
.items
277+
.iter()
278+
.find_map(|item| match item {
279+
Item::Trait(t) if t.ident == "Dialect" => Some(t),
280+
_ => None,
281+
})
282+
.ok_or_else(|| Error::new(proc_macro2::Span::call_site(), "Dialect trait not found"))?;
283+
284+
let mut methods: Vec<_> = dialect_trait
285+
.items
286+
.iter()
287+
.filter_map(|item| match item {
288+
TraitItem::Fn(m) if m.sig.ident != "dialect" => Some(DialectMethod {
289+
name: m.sig.ident.clone(),
290+
signature: m.sig.clone(),
291+
}),
292+
_ => None,
293+
})
294+
.collect();
295+
methods.sort_by_key(|m| m.name.to_string());
296+
Ok(methods)
297+
}
298+
299+
/// Check if a method signature is `fn name(&self) -> bool`
300+
fn is_bool_method(sig: &Signature) -> bool {
301+
sig.inputs.len() == 1
302+
&& matches!(
303+
sig.inputs.first(),
304+
Some(FnArg::Receiver(r)) if r.reference.is_some() && r.mutability.is_none()
305+
)
306+
&& matches!(
307+
&sig.output,
308+
ReturnType::Type(_, ty) if matches!(ty.as_ref(), Type::Path(p) if p.path.is_ident("bool"))
309+
)
310+
}

0 commit comments

Comments
 (0)