Skip to content

Commit bd071be

Browse files
authored
feat: add custom_string_literal_override to unparser Dialect trait (#20590)
## Which issue does this PR close?

- Closes #.

## Rationale for this change

When unparsing queries targeting databases like MSSQL, non-ASCII string literals need special handling. MSSQL requires the `N` prefix (national string literal, `N'...'`) for strings containing non-ASCII characters. Currently, the unparser always emits plain single-quoted strings, with no way for dialects to customize this behavior.

## What changes are included in this PR?

- Add a new `custom_string_literal_override` method to the `Dialect` trait with a default implementation returning `None` (no override). Note: in the diff below, the method appears under the name `string_literal_to_sql`.
- Consolidate the `Utf8`, `Utf8View`, and `LargeUtf8` match arms in `scalar_value_to_sql` and route them through the new dialect hook.

## Are these changes tested?

Yes. A test-only `MsSqlDialect` is defined in the test module to verify:

- ASCII strings produce standard single-quoted literals (no `N` prefix)
- Non-ASCII strings produce national string literals (`N'...'`)
- The default dialect is unaffected (no `N` prefix regardless of content)

It has been used by Wren AI in production for a while: Canner#8

## Are there any user-facing changes?

Yes. The `Dialect` trait gains a new method `custom_string_literal_override` (`string_literal_to_sql` in the diff below). This is a non-breaking change since the method has a default implementation. Dialect implementors can override it to customize string literal unparsing.
1 parent 4166a6d commit bd071be

2 files changed

Lines changed: 84 additions & 11 deletions

File tree

datafusion/sql/src/unparser/dialect.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,17 @@ pub trait Dialect: Send + Sync {
248248
fn supports_empty_select_list(&self) -> bool {
249249
false
250250
}
251+
252+
/// Override the default string literal unparsing.
253+
///
254+
/// Returns `Some(ast::Expr)` to replace the default single-quoted string,
255+
/// or `None` to use the default behavior.
256+
///
257+
/// For example, MSSQL requires non-ASCII strings to use national string
258+
/// literal syntax (`N'datafusion資料融合'`).
259+
fn string_literal_to_sql(&self, _s: &str) -> Option<ast::Expr> {
260+
None
261+
}
251262
}
252263

253264
/// `IntervalStyle` to use for unparsing

datafusion/sql/src/unparser/expr.rs

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,18 +1294,17 @@ impl Unparser<'_> {
12941294
Ok(ast::Expr::value(ast::Value::Number(ui.to_string(), false)))
12951295
}
12961296
ScalarValue::UInt64(None) => Ok(ast::Expr::value(ast::Value::Null)),
1297-
ScalarValue::Utf8(Some(str)) => {
1298-
Ok(ast::Expr::value(SingleQuotedString(str.to_string())))
1299-
}
1300-
ScalarValue::Utf8(None) => Ok(ast::Expr::value(ast::Value::Null)),
1301-
ScalarValue::Utf8View(Some(str)) => {
1302-
Ok(ast::Expr::value(SingleQuotedString(str.to_string())))
1303-
}
1304-
ScalarValue::Utf8View(None) => Ok(ast::Expr::value(ast::Value::Null)),
1305-
ScalarValue::LargeUtf8(Some(str)) => {
1297+
ScalarValue::Utf8(Some(str))
1298+
| ScalarValue::Utf8View(Some(str))
1299+
| ScalarValue::LargeUtf8(Some(str)) => {
1300+
if let Some(expr) = self.dialect.string_literal_to_sql(str) {
1301+
return Ok(expr);
1302+
}
13061303
Ok(ast::Expr::value(SingleQuotedString(str.to_string())))
13071304
}
1308-
ScalarValue::LargeUtf8(None) => Ok(ast::Expr::value(ast::Value::Null)),
1305+
ScalarValue::Utf8(None)
1306+
| ScalarValue::Utf8View(None)
1307+
| ScalarValue::LargeUtf8(None) => Ok(ast::Expr::value(ast::Value::Null)),
13091308
ScalarValue::Binary(Some(_)) => not_impl_err!("Unsupported scalar: {v:?}"),
13101309
ScalarValue::Binary(None) => Ok(ast::Expr::value(ast::Value::Null)),
13111310
ScalarValue::BinaryView(Some(_)) => {
@@ -2397,7 +2396,6 @@ mod tests {
23972396

23982397
let expected = r#"('a' > 4)"#;
23992398
assert_eq!(actual, expected);
2400-
24012399
Ok(())
24022400
}
24032401

@@ -2960,6 +2958,70 @@ mod tests {
29602958
Ok(())
29612959
}
29622960

2961+
#[test]
2962+
fn test_mssql_dialect_national_literal() -> Result<()> {
2963+
struct MsSqlDialect;
2964+
2965+
impl Dialect for MsSqlDialect {
2966+
fn identifier_quote_style(&self, _identifier: &str) -> Option<char> {
2967+
Some('[')
2968+
}
2969+
2970+
fn string_literal_to_sql(&self, s: &str) -> Option<ast::Expr> {
2971+
if !s.is_ascii() {
2972+
Some(ast::Expr::value(ast::Value::NationalStringLiteral(
2973+
s.to_string(),
2974+
)))
2975+
} else {
2976+
None
2977+
}
2978+
}
2979+
}
2980+
2981+
let dialect = MsSqlDialect;
2982+
let unparser = Unparser::new(&dialect);
2983+
2984+
// The custom MSSQL dialect emits a national string literal (N'...') only for non-ASCII input
2985+
for (s, expected) in [
2986+
("national string", "'national string'"),
2987+
("datafusion資料融合", "N'datafusion資料融合'"),
2988+
] {
2989+
let expr = Expr::Literal(ScalarValue::Utf8(Some(s.to_string())), None);
2990+
let ast = unparser.expr_to_sql(&expr)?;
2991+
assert_eq!(ast.to_string(), expected);
2992+
2993+
let expr = Expr::Literal(ScalarValue::Utf8View(Some(s.to_string())), None);
2994+
let ast = unparser.expr_to_sql(&expr)?;
2995+
assert_eq!(ast.to_string(), expected);
2996+
2997+
let expr = Expr::Literal(ScalarValue::LargeUtf8(Some(s.to_string())), None);
2998+
let ast = unparser.expr_to_sql(&expr)?;
2999+
assert_eq!(ast.to_string(), expected);
3000+
}
3001+
3002+
let dialect = DefaultDialect {};
3003+
let unparser = Unparser::new(&dialect);
3004+
3005+
// Get normal string literal for default dialect
3006+
for (s, expected) in [
3007+
("national string", "'national string'"),
3008+
("datafusion資料融合", "'datafusion資料融合'"),
3009+
] {
3010+
let expr = Expr::Literal(ScalarValue::Utf8(Some(s.to_string())), None);
3011+
let ast = unparser.expr_to_sql(&expr)?;
3012+
assert_eq!(ast.to_string(), expected);
3013+
3014+
let expr = Expr::Literal(ScalarValue::Utf8View(Some(s.to_string())), None);
3015+
let ast = unparser.expr_to_sql(&expr)?;
3016+
assert_eq!(ast.to_string(), expected);
3017+
3018+
let expr = Expr::Literal(ScalarValue::LargeUtf8(Some(s.to_string())), None);
3019+
let ast = unparser.expr_to_sql(&expr)?;
3020+
assert_eq!(ast.to_string(), expected);
3021+
}
3022+
Ok(())
3023+
}
3024+
29633025
#[test]
29643026
fn test_cast_value_to_dict_expr() {
29653027
let tests = [(

0 commit comments

Comments
 (0)