Skip to content

Commit 22e08bc

Browse files
unknowntpo and claude authored
feat: support Spark-compatible string_to_map function (#20120)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Part of #15914 - Related comet issue: apache/datafusion-comet#3168 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> - Apache Spark's `str_to_map` creates a map by splitting a string into key-value pairs using delimiters. - This function is used in Spark SQL and needed for DataFusion-Comet compatibility. - `LAST_WIN` policy of handling duplicate key will be implemented in next PR. - Reference: https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> - Add Spark-compatible `str_to_map` function in `datafusion-spark` crate - Function signature: `str_to_map(text, [pairDelim], [keyValueDelim]) -> Map<String, String>` - `text`: The input string - `pairDelim`: Delimiter between key-value pairs (default: `,`) - `keyValueDelim`: Delimiter between key and value (default: `:`) - Located in `function/map/` module (returns Map type) ### Examples ```sql SELECT str_to_map('a:1,b:2,c:3'); -- {a: 1, b: 2, c: 3} SELECT str_to_map('a=1;b=2', ';', '='); -- {a: 1, b: 2} SELECT str_to_map('key:value'); -- {key: value} ``` ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. 
Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> - sqllogictest: `test_files/spark/map/string_to_map.slt`, test cases derived from [Spark test("StringToMap")](https://github.com/apache/spark/blob/v4.0.0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala#L525-L618 ): ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> Yes. <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 6a2bfd4 commit 22e08bc

3 files changed

Lines changed: 389 additions & 1 deletion

File tree

datafusion/spark/src/function/map/mod.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
pub mod map_from_arrays;
1919
pub mod map_from_entries;
20+
pub mod str_to_map;
2021
mod utils;
2122

2223
use datafusion_expr::ScalarUDF;
@@ -25,6 +26,7 @@ use std::sync::Arc;
2526

2627
make_udf_function!(map_from_arrays::MapFromArrays, map_from_arrays);
2728
make_udf_function!(map_from_entries::MapFromEntries, map_from_entries);
29+
make_udf_function!(str_to_map::SparkStrToMap, str_to_map);
2830

2931
pub mod expr_fn {
3032
use datafusion_functions::export_functions;
@@ -40,8 +42,14 @@ pub mod expr_fn {
4042
"Creates a map from array<struct<key, value>>.",
4143
arg1
4244
));
45+
46+
export_functions!((
47+
str_to_map,
48+
"Creates a map after splitting the text into key/value pairs using delimiters.",
49+
text pair_delim key_value_delim
50+
));
4351
}
4452

4553
pub fn functions() -> Vec<Arc<ScalarUDF>> {
46-
vec![map_from_arrays(), map_from_entries()]
54+
vec![map_from_arrays(), map_from_entries(), str_to_map()]
4755
}
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::any::Any;
19+
use std::collections::HashSet;
20+
use std::sync::Arc;
21+
22+
use arrow::array::{
23+
Array, ArrayRef, MapBuilder, MapFieldNames, StringArrayType, StringBuilder,
24+
};
25+
use arrow::buffer::NullBuffer;
26+
use arrow::datatypes::{DataType, Field, FieldRef};
27+
use datafusion_common::cast::{
28+
as_large_string_array, as_string_array, as_string_view_array,
29+
};
30+
use datafusion_common::{Result, exec_err, internal_err};
31+
use datafusion_expr::{
32+
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
33+
TypeSignature, Volatility,
34+
};
35+
36+
use crate::function::map::utils::map_type_from_key_value_types;
37+
38+
const DEFAULT_PAIR_DELIM: &str = ",";
39+
const DEFAULT_KV_DELIM: &str = ":";
40+
41+
/// Spark-compatible `str_to_map` expression
42+
/// <https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map>
43+
///
44+
/// Creates a map from a string by splitting on delimiters.
45+
/// str_to_map(text[, pairDelim[, keyValueDelim]]) -> Map<String, String>
46+
///
47+
/// - text: The input string
48+
/// - pairDelim: Delimiter between key-value pairs (default: ',')
49+
/// - keyValueDelim: Delimiter between key and value (default: ':')
50+
///
51+
/// # Duplicate Key Handling
52+
/// Uses EXCEPTION behavior (Spark 3.0+ default): errors on duplicate keys.
53+
/// See `spark.sql.mapKeyDedupPolicy`:
54+
/// <https://github.com/apache/spark/blob/v4.0.0/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4502-L4511>
55+
///
56+
/// TODO: Support configurable `spark.sql.mapKeyDedupPolicy` (LAST_WIN) in a follow-up PR.
57+
#[derive(Debug, PartialEq, Eq, Hash)]
58+
pub struct SparkStrToMap {
59+
signature: Signature,
60+
}
61+
62+
impl Default for SparkStrToMap {
63+
fn default() -> Self {
64+
Self::new()
65+
}
66+
}
67+
68+
impl SparkStrToMap {
69+
pub fn new() -> Self {
70+
Self {
71+
signature: Signature::one_of(
72+
vec![
73+
// str_to_map(text)
74+
TypeSignature::String(1),
75+
// str_to_map(text, pairDelim)
76+
TypeSignature::String(2),
77+
// str_to_map(text, pairDelim, keyValueDelim)
78+
TypeSignature::String(3),
79+
],
80+
Volatility::Immutable,
81+
),
82+
}
83+
}
84+
}
85+
86+
impl ScalarUDFImpl for SparkStrToMap {
87+
fn as_any(&self) -> &dyn Any {
88+
self
89+
}
90+
91+
fn name(&self) -> &str {
92+
"str_to_map"
93+
}
94+
95+
fn signature(&self) -> &Signature {
96+
&self.signature
97+
}
98+
99+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
100+
internal_err!("return_field_from_args should be used instead")
101+
}
102+
103+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
104+
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
105+
let map_type = map_type_from_key_value_types(&DataType::Utf8, &DataType::Utf8);
106+
Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
107+
}
108+
109+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
110+
let arrays: Vec<ArrayRef> = ColumnarValue::values_to_arrays(&args.args)?;
111+
let result = str_to_map_inner(&arrays)?;
112+
Ok(ColumnarValue::Array(result))
113+
}
114+
}
115+
116+
fn str_to_map_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
117+
match args.len() {
118+
1 => match args[0].data_type() {
119+
DataType::Utf8 => str_to_map_impl(as_string_array(&args[0])?, None, None),
120+
DataType::LargeUtf8 => {
121+
str_to_map_impl(as_large_string_array(&args[0])?, None, None)
122+
}
123+
DataType::Utf8View => {
124+
str_to_map_impl(as_string_view_array(&args[0])?, None, None)
125+
}
126+
other => exec_err!(
127+
"Unsupported data type {other:?} for str_to_map, \
128+
expected Utf8, LargeUtf8, or Utf8View"
129+
),
130+
},
131+
2 => match (args[0].data_type(), args[1].data_type()) {
132+
(DataType::Utf8, DataType::Utf8) => str_to_map_impl(
133+
as_string_array(&args[0])?,
134+
Some(as_string_array(&args[1])?),
135+
None,
136+
),
137+
(DataType::LargeUtf8, DataType::LargeUtf8) => str_to_map_impl(
138+
as_large_string_array(&args[0])?,
139+
Some(as_large_string_array(&args[1])?),
140+
None,
141+
),
142+
(DataType::Utf8View, DataType::Utf8View) => str_to_map_impl(
143+
as_string_view_array(&args[0])?,
144+
Some(as_string_view_array(&args[1])?),
145+
None,
146+
),
147+
(t1, t2) => exec_err!(
148+
"Unsupported data types ({t1:?}, {t2:?}) for str_to_map, \
149+
expected matching Utf8, LargeUtf8, or Utf8View"
150+
),
151+
},
152+
3 => match (
153+
args[0].data_type(),
154+
args[1].data_type(),
155+
args[2].data_type(),
156+
) {
157+
(DataType::Utf8, DataType::Utf8, DataType::Utf8) => str_to_map_impl(
158+
as_string_array(&args[0])?,
159+
Some(as_string_array(&args[1])?),
160+
Some(as_string_array(&args[2])?),
161+
),
162+
(DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
163+
str_to_map_impl(
164+
as_large_string_array(&args[0])?,
165+
Some(as_large_string_array(&args[1])?),
166+
Some(as_large_string_array(&args[2])?),
167+
)
168+
}
169+
(DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
170+
str_to_map_impl(
171+
as_string_view_array(&args[0])?,
172+
Some(as_string_view_array(&args[1])?),
173+
Some(as_string_view_array(&args[2])?),
174+
)
175+
}
176+
(t1, t2, t3) => exec_err!(
177+
"Unsupported data types ({t1:?}, {t2:?}, {t3:?}) for str_to_map, \
178+
expected matching Utf8, LargeUtf8, or Utf8View"
179+
),
180+
},
181+
n => exec_err!("str_to_map expects 1-3 arguments, got {n}"),
182+
}
183+
}
184+
185+
fn str_to_map_impl<'a, V: StringArrayType<'a> + Copy>(
186+
text_array: V,
187+
pair_delim_array: Option<V>,
188+
kv_delim_array: Option<V>,
189+
) -> Result<ArrayRef> {
190+
let num_rows = text_array.len();
191+
192+
// Precompute combined null buffer from all input arrays.
193+
// NullBuffer::union performs a bitmap-level AND, which is more efficient
194+
// than checking per-row nullability inline.
195+
let text_nulls = text_array.nulls().cloned();
196+
let pair_nulls = pair_delim_array.and_then(|a| a.nulls().cloned());
197+
let kv_nulls = kv_delim_array.and_then(|a| a.nulls().cloned());
198+
let combined_nulls = [text_nulls.as_ref(), pair_nulls.as_ref(), kv_nulls.as_ref()]
199+
.into_iter()
200+
.fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));
201+
202+
// Use field names matching map_type_from_key_value_types: "key" and "value"
203+
let field_names = MapFieldNames {
204+
entry: "entries".to_string(),
205+
key: "key".to_string(),
206+
value: "value".to_string(),
207+
};
208+
let mut map_builder = MapBuilder::new(
209+
Some(field_names),
210+
StringBuilder::new(),
211+
StringBuilder::new(),
212+
);
213+
214+
let mut seen_keys = HashSet::new();
215+
for row_idx in 0..num_rows {
216+
if combined_nulls.as_ref().is_some_and(|n| n.is_null(row_idx)) {
217+
map_builder.append(false)?;
218+
continue;
219+
}
220+
221+
// Per-row delimiter extraction
222+
let pair_delim =
223+
pair_delim_array.map_or(DEFAULT_PAIR_DELIM, |a| a.value(row_idx));
224+
let kv_delim = kv_delim_array.map_or(DEFAULT_KV_DELIM, |a| a.value(row_idx));
225+
226+
let text = text_array.value(row_idx);
227+
if text.is_empty() {
228+
// Empty string -> map with empty key and NULL value (Spark behavior)
229+
map_builder.keys().append_value("");
230+
map_builder.values().append_null();
231+
map_builder.append(true)?;
232+
continue;
233+
}
234+
235+
seen_keys.clear();
236+
for pair in text.split(pair_delim) {
237+
if pair.is_empty() {
238+
continue;
239+
}
240+
241+
let mut kv_iter = pair.splitn(2, kv_delim);
242+
let key = kv_iter.next().unwrap_or("");
243+
let value = kv_iter.next();
244+
245+
// TODO: Support LAST_WIN policy via spark.sql.mapKeyDedupPolicy config
246+
// EXCEPTION policy: error on duplicate keys (Spark 3.0+ default)
247+
if !seen_keys.insert(key) {
248+
return exec_err!(
249+
"Duplicate map key '{key}' was found, please check the input data. \
250+
If you want to remove the duplicated keys, you can set \
251+
spark.sql.mapKeyDedupPolicy to \"LAST_WIN\" so that the key \
252+
inserted at last takes precedence."
253+
);
254+
}
255+
256+
map_builder.keys().append_value(key);
257+
match value {
258+
Some(v) => map_builder.values().append_value(v),
259+
None => map_builder.values().append_null(),
260+
}
261+
}
262+
map_builder.append(true)?;
263+
}
264+
265+
Ok(Arc::new(map_builder.finish()))
266+
}

0 commit comments

Comments
 (0)