Commit 75428f1

fix: Avoid integer overflow in split_part() (#20198)
Along the way, improve the docs slightly.

## Rationale for this change

Evaluating `SELECT SPLIT_PART('', '', -9223372036854775808);` yields (in a debug build):

```
thread 'main' (41405991) panicked at datafusion/functions/src/string/split_part.rs:236:47:
attempt to negate with overflow
```

## Are these changes tested?

Yes, added unit test.

## Are there any user-facing changes?
1 parent 15f38aa commit 75428f1
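
The overflow described in the rationale can be reproduced outside DataFusion. The following is a minimal sketch, not the project's code: `index_from_end` is a hypothetical helper name introduced here for illustration. It shows why negating `i64::MIN` cannot succeed and why `unsigned_abs()` avoids the problem when turning a negative position into a zero-based index from the end.

```rust
/// Hypothetical helper (not part of DataFusion): converts a negative,
/// 1-based position such as -1 or -3 into a 0-based index counted from
/// the end. Returns None if the index does not fit in usize.
fn index_from_end(n: i64) -> Option<usize> {
    debug_assert!(n < 0, "this sketch only handles negative positions");
    // unsigned_abs() returns a u64, so |i64::MIN| is representable.
    // Writing `-n - 1` instead would overflow for n == i64::MIN.
    usize::try_from(n.unsigned_abs() - 1).ok()
}

fn main() {
    // i64::MIN has no positive counterpart in i64, so negation fails.
    assert_eq!(i64::MIN.checked_neg(), None);

    // The magnitude fits in u64.
    assert_eq!(i64::MIN.unsigned_abs(), 9_223_372_036_854_775_808_u64);

    // Ordinary negative positions map to indices from the end.
    assert_eq!(index_from_end(-1), Some(0));
    assert_eq!(index_from_end(-3), Some(2));
}
```

In a release build the original expression would typically wrap instead of panicking, since overflow checks are off by default there; avoiding the negation sidesteps both behaviors.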

2 files changed

Lines changed: 20 additions & 3 deletions

datafusion/functions/src/string/split_part.rs

Lines changed: 19 additions & 2 deletions
@@ -48,7 +48,10 @@ use std::sync::Arc;
 ```"#,
     standard_argument(name = "str", prefix = "String"),
     argument(name = "delimiter", description = "String or character to split on."),
-    argument(name = "pos", description = "Position of the part to return.")
+    argument(
+        name = "pos",
+        description = "Position of the part to return (counting from 1). Negative values count backward from the end of the string."
+    )
 )]
 #[derive(Debug, PartialEq, Eq, Hash)]
 pub struct SplitPartFunc {
@@ -233,7 +236,7 @@ where
     std::cmp::Ordering::Less => {
         // Negative index: use rsplit().nth() to efficiently get from the end
         // rsplit iterates in reverse, so -1 means first from rsplit (index 0)
-        let idx: usize = (-n - 1).try_into().map_err(|_| {
+        let idx: usize = (n.unsigned_abs() - 1).try_into().map_err(|_| {
             exec_datafusion_err!(
                 "split_part index {n} exceeds minimum supported value"
             )
@@ -324,6 +327,20 @@ mod tests {
             Utf8,
             StringArray
         );
+        test_function!(
+            SplitPartFunc::new(),
+            vec![
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
+                    "abc~@~def~@~ghi"
+                )))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))),
+                ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MIN))),
+            ],
+            Ok(Some("")),
+            &str,
+            Utf8,
+            StringArray
+        );
 
         Ok(())
     }
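
To make the indexing behavior exercised by the new test concrete, here is a self-contained sketch of the lookup using only the standard library. It is an illustration under stated assumptions, not the DataFusion implementation: `split_part_sketch` is a hypothetical name, errors are plain `String`s rather than DataFusion errors, rejecting position 0 is an assumption of this sketch, and empty delimiters get no special handling.

```rust
use std::cmp::Ordering;

/// Hypothetical stand-in for the split_part lookup (not DataFusion code).
/// Positive positions count from 1 at the start; negative positions count
/// backward from the end. Out-of-range positions yield an empty string.
fn split_part_sketch<'a>(s: &'a str, delimiter: &str, n: i64) -> Result<&'a str, String> {
    let part = match n.cmp(&0) {
        Ordering::Greater => {
            let idx = usize::try_from(n - 1)
                .map_err(|_| format!("index {n} exceeds maximum supported value"))?;
            s.split(delimiter).nth(idx)
        }
        Ordering::Less => {
            // unsigned_abs() avoids the overflow that `-n` hits for i64::MIN.
            let idx = usize::try_from(n.unsigned_abs() - 1)
                .map_err(|_| format!("index {n} exceeds minimum supported value"))?;
            // rsplit iterates from the end, so -1 maps to index 0.
            s.rsplit(delimiter).nth(idx)
        }
        // Assumption of this sketch: position 0 is rejected.
        Ordering::Equal => return Err("position must not be zero".to_string()),
    };
    Ok(part.unwrap_or(""))
}

fn main() {
    let s = "abc~@~def~@~ghi";
    assert_eq!(split_part_sketch(s, "~@~", 2), Ok("def"));
    assert_eq!(split_part_sketch(s, "~@~", -1), Ok("ghi"));
    assert_eq!(split_part_sketch(s, "~@~", -3), Ok("abc"));
    assert_eq!(split_part_sketch(s, "~@~", -4), Ok(""));

    // On a 64-bit target the magnitude of i64::MIN minus one fits in usize,
    // so the lookup runs off the end and yields an empty string.
    if cfg!(target_pointer_width = "64") {
        assert_eq!(split_part_sketch(s, "~@~", i64::MIN), Ok(""));
    }
}
```

On targets where `usize` is narrower than 64 bits, the conversion fails for `i64::MIN` and the sketch returns the error branch instead, which parallels the `map_err` path in the patched code.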

docs/source/user-guide/sql/scalar_functions.md

Lines changed: 1 addition & 1 deletion
@@ -1891,7 +1891,7 @@ split_part(str, delimiter, pos)
 
 - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
 - **delimiter**: String or character to split on.
-- **pos**: Position of the part to return.
+- **pos**: Position of the part to return (counting from 1). Negative values count backward from the end of the string.
 
 #### Example
 