diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index 663e7928bfd95..a0c3784dbeee5 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -18,15 +18,20 @@ use std::hint::black_box; use std::sync::Arc; -use arrow::array::{ArrayRef, Int64Array, StringArray}; +use arrow::array::{ArrayRef, Int64Array, StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use criterion::{Criterion, criterion_group, criterion_main}; -use datafusion_common::config::ConfigOptions; +use datafusion_common::{ScalarValue, config::ConfigOptions}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode::substr_index; use rand::Rng; +use rand::SeedableRng; use rand::distr::{Alphanumeric, Uniform}; use rand::prelude::Distribution; +use rand::rngs::StdRng; + +const ARRAY_DATA_SEED: u64 = 0x5EED_AAAA; +const SCALAR_DATA_SEED: u64 = 0x5EED_BBBB; struct Filter { dist: Dist, @@ -48,28 +53,72 @@ where } } -fn data( +#[derive(Clone, Copy)] +enum StringRep { + Utf8, + Utf8View, +} + +impl StringRep { + fn name(self) -> &'static str { + match self { + Self::Utf8 => "utf8", + Self::Utf8View => "utf8view", + } + } + + fn data_type(self) -> DataType { + match self { + Self::Utf8 => DataType::Utf8, + Self::Utf8View => DataType::Utf8View, + } + } + + fn array(self, values: &[String]) -> ArrayRef { + match self { + Self::Utf8 => Arc::new(StringArray::from(values.to_vec())) as ArrayRef, + Self::Utf8View => Arc::new( + values + .iter() + .map(|value| Some(value.as_str())) + .collect::(), + ) as ArrayRef, + } + } + + fn scalar(self, value: &str) -> ScalarValue { + match self { + Self::Utf8 => ScalarValue::Utf8(Some(value.to_string())), + Self::Utf8View => ScalarValue::Utf8View(Some(value.to_string())), + } + } +} + +fn random_token(rng: &mut R, len: usize) -> String { + rng.sample_iter(&Alphanumeric) + .take(len) + .map(char::from) + .collect() +} + +fn array_data( batch_size: usize, single_char_delimiter: bool, -) -> (StringArray, StringArray, Int64Array) { - let dist = Filter { - dist: Uniform::new(-4, 5), +) -> (Vec, Vec, Vec) { + let count_dist = Filter { + dist: Uniform::new(-4, 5).expect("valid count distribution"), test: |x: &i64| x != &0, }; - let mut rng = rand::rng(); - let mut strings: Vec = vec![]; - let mut delimiters: Vec = vec![]; - let mut counts: Vec = vec![]; + let mut rng = StdRng::seed_from_u64(ARRAY_DATA_SEED); + let mut strings = Vec::with_capacity(batch_size); + let mut delimiters = Vec::with_capacity(batch_size); + let mut counts = Vec::with_capacity(batch_size); for _ in 0..batch_size { let length = rng.random_range(20..50); - let base: String = (&mut rng) - .sample_iter(&Alphanumeric) - .take(length) - .map(char::from) - .collect(); + let base = random_token(&mut rng, length); - let (string_value, delimiter): (String, String) = if single_char_delimiter { + let (string_value, delimiter) = if single_char_delimiter { let char_idx = rng.random_range(0..base.chars().count()); let delimiter = base.chars().nth(char_idx).unwrap().to_string(); (base, delimiter) @@ -80,7 +129,6 @@ fn data( let delimiter_count = rng.random_range(1..4); let mut result = String::new(); - for i in 0..delimiter_count { result.push_str(&base); if i < delimiter_count - 1 { @@ -90,32 +138,37 @@ fn data( (result, delimiter) }; - let count = rng.sample(dist.dist.unwrap()); - strings.push(string_value); delimiters.push(delimiter); - counts.push(count); + counts.push(count_dist.sample(&mut rng)); + } + + (strings, delimiters, counts) +} + +fn scalar_data(batch_size: usize, delimiter: &str) -> Vec { + let mut rng = StdRng::seed_from_u64(SCALAR_DATA_SEED); + let mut strings = Vec::with_capacity(batch_size); + + for _ in 0..batch_size { + let left_len = rng.random_range(12..24); + let middle_len = rng.random_range(12..24); + let right_len = rng.random_range(12..24); + let left = random_token(&mut rng, left_len); + let middle = random_token(&mut rng, middle_len); + let right = random_token(&mut rng, right_len); + strings.push(format!("{left}{delimiter}{middle}{delimiter}{right}")); } - ( - StringArray::from(strings), - StringArray::from(delimiters), - Int64Array::from(counts), - ) + strings } fn run_benchmark( b: &mut criterion::Bencher, - strings: StringArray, - delimiters: StringArray, - counts: Int64Array, - batch_size: usize, + args: &[ColumnarValue], + return_type: &DataType, + number_rows: usize, ) { - let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef); - let delimiters = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); - let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); - - let args = vec![strings, delimiters, counts]; let arg_fields = args .iter() .enumerate() @@ -129,10 +182,10 @@ fn run_benchmark( black_box( substr_index() .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), + args: args.to_vec(), arg_fields: arg_fields.clone(), - number_rows: batch_size, - return_field: Field::new("f", DataType::Utf8, true).into(), + number_rows, + return_field: Field::new("f", return_type.clone(), true).into(), config_options: Arc::clone(&config_options), }) .expect("substr_index should work on valid values"), @@ -142,22 +195,62 @@ fn run_benchmark( fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("substr_index"); - let batch_sizes = [100, 1000, 10_000]; for batch_size in batch_sizes { - group.bench_function( - format!("substr_index_{batch_size}_single_delimiter"), - |b| { - let (strings, delimiters, counts) = data(batch_size, true); - run_benchmark(b, strings, delimiters, counts, batch_size); - }, - ); - - group.bench_function(format!("substr_index_{batch_size}_long_delimiter"), |b| { - let (strings, delimiters, counts) = data(batch_size, false); - run_benchmark(b, strings, delimiters, counts, batch_size); - }); + for rep in [StringRep::Utf8, StringRep::Utf8View] { + let rep_name = rep.name(); + + group.bench_function( + format!("substr_index_{rep_name}_{batch_size}_array_single_delimiter"), + |b| { + let (strings, delimiters, counts) = array_data(batch_size, true); + let args = vec![ + ColumnarValue::Array(rep.array(&strings)), + ColumnarValue::Array(rep.array(&delimiters)), + ColumnarValue::Array( + Arc::new(Int64Array::from(counts)) as ArrayRef + ), + ]; + run_benchmark(b, &args, &rep.data_type(), batch_size); + }, + ); + + group.bench_function( + format!("substr_index_{rep_name}_{batch_size}_array_long_delimiter"), + |b| { + let (strings, delimiters, counts) = array_data(batch_size, false); + let args = vec![ + ColumnarValue::Array(rep.array(&strings)), + ColumnarValue::Array(rep.array(&delimiters)), + ColumnarValue::Array( + Arc::new(Int64Array::from(counts)) as ArrayRef + ), + ]; + run_benchmark(b, &args, &rep.data_type(), batch_size); + }, + ); + + for (name, delimiter, count) in [ + ("single_delimiter_pos", ".", 1_i64), + ("single_delimiter_neg", ".", -1_i64), + ("long_delimiter_pos", "|||", 1_i64), + ("long_delimiter_neg", "|||", -1_i64), + ] { + group.bench_function( + format!("substr_index_{rep_name}_{batch_size}_scalar_{name}"), + |b| { + let strings = scalar_data(batch_size, delimiter); + let args = vec![ + ColumnarValue::Array(rep.array(&strings)), + ColumnarValue::Scalar(rep.scalar(delimiter)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(count))), + ]; + run_benchmark(b, &args, &rep.data_type(), batch_size); + }, + ); + } + } } group.finish(); diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index c04656282d942..200da45ec95e9 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -18,19 +18,25 @@ use std::sync::Arc; use arrow::array::{ - ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, - GenericStringBuilder, OffsetSizeTrait, PrimitiveArray, + Array, ArrayRef, AsArray, ByteView, GenericStringArray, GenericStringBuilder, + OffsetSizeTrait, PrimitiveArray, StringArrayType, StringLikeArrayBuilder, + StringViewArray, make_view, new_null_array, }; -use arrow::datatypes::{DataType, Int32Type, Int64Type}; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::{DataType, Int64Type}; +use arrow_buffer::NullBuffer; -use crate::utils::{make_scalar_function, utf8_to_str_type}; -use datafusion_common::{Result, exec_err, utils::take_function_args}; +use crate::utils::make_scalar_function; +use datafusion_common::{ + Result, ScalarValue, exec_datafusion_err, exec_err, utils::take_function_args, +}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; +use memchr::{memchr_iter, memmem, memrchr_iter}; #[user_doc( doc_section(label = "String Functions"), @@ -101,11 +107,22 @@ impl ScalarUDFImpl for SubstrIndexFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "substr_index") + Ok(arg_types[0].clone()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(substr_index, vec![])(&args.args) + let ScalarFunctionArgs { args, .. } = args; + + if let ( + ColumnarValue::Array(string_array), + ColumnarValue::Scalar(delim_scalar), + ColumnarValue::Scalar(count_scalar), + ) = (&args[0], &args[1], &args[2]) + { + return substr_index_scalar(string_array, delim_scalar, count_scalar); + } + + make_scalar_function(substr_index, vec![])(&args) } fn aliases(&self) -> &[String] { @@ -130,31 +147,35 @@ fn substr_index(args: &[ArrayRef]) -> Result { let string_array = str.as_string::(); let delimiter_array = delim.as_string::(); let count_array: &PrimitiveArray = count.as_primitive(); - substr_index_general::( + substr_index_general( string_array, delimiter_array, count_array, + GenericStringBuilder::::with_capacity( + string_array.len(), + visible_string_bytes(string_array), + ), ) } DataType::LargeUtf8 => { let string_array = str.as_string::(); let delimiter_array = delim.as_string::(); let count_array: &PrimitiveArray = count.as_primitive(); - substr_index_general::( + substr_index_general( string_array, delimiter_array, count_array, + GenericStringBuilder::::with_capacity( + string_array.len(), + visible_string_bytes(string_array), + ), ) } DataType::Utf8View => { let string_array = str.as_string_view(); let delimiter_array = delim.as_string_view(); let count_array: &PrimitiveArray = count.as_primitive(); - substr_index_general::( - string_array, - delimiter_array, - count_array, - ) + substr_index_view(string_array, delimiter_array, count_array) } other => { exec_err!("Unsupported data type {other:?} for function substr_index") @@ -162,94 +183,441 @@ fn substr_index(args: &[ArrayRef]) -> Result { } } -fn substr_index_general< - 'a, - T: ArrowPrimitiveType, - V: ArrayAccessor, - P: ArrayAccessor, ->( - string_array: V, - delimiter_array: V, - count_array: P, +fn substr_index_scalar( + string_array: &ArrayRef, + delim_scalar: &ScalarValue, + count_scalar: &ScalarValue, +) -> Result { + if string_array.is_empty() { + return Ok(ColumnarValue::Array(new_null_array( + string_array.data_type(), + 0, + ))); + } + + let delimiter = delim_scalar.try_as_str().ok_or_else(|| { + exec_datafusion_err!( + "Unsupported delimiter type {:?} for substr_index", + delim_scalar.data_type() + ) + })?; + + let count = match count_scalar { + ScalarValue::Int64(v) => *v, + other => { + return exec_err!( + "Unsupported count type {:?} for substr_index", + other.data_type() + ); + } + }; + + let (Some(delimiter), Some(count)) = (delimiter, count) else { + return Ok(ColumnarValue::Array(new_null_array( + string_array.data_type(), + string_array.len(), + ))); + }; + + let result = match string_array.data_type() { + DataType::Utf8View => { + substr_index_scalar_view(string_array.as_string_view(), delimiter, count) + } + DataType::Utf8 => { + let arr = string_array.as_string::(); + substr_index_scalar_impl( + arr, + delimiter, + count, + GenericStringBuilder::::with_capacity( + arr.len(), + visible_string_bytes(arr), + ), + ) + } + DataType::LargeUtf8 => { + let arr = string_array.as_string::(); + substr_index_scalar_impl( + arr, + delimiter, + count, + GenericStringBuilder::::with_capacity( + arr.len(), + visible_string_bytes(arr), + ), + ) + } + other => exec_err!("Unsupported string type {other:?} for substr_index"), + }?; + + Ok(ColumnarValue::Array(result)) +} + +#[inline] +fn visible_string_bytes( + string_array: &GenericStringArray, +) -> usize { + let offsets = string_array.value_offsets(); + offsets[offsets.len() - 1].as_usize() - offsets[0].as_usize() +} + +fn substr_index_general<'a, S, B>( + string_array: S, + delimiter_array: S, + count_array: &PrimitiveArray, + mut builder: B, ) -> Result where - T::Native: OffsetSizeTrait, + S: StringArrayType<'a> + Copy, + B: StringLikeArrayBuilder, { - let num_rows = string_array.len(); - let mut builder = GenericStringBuilder::::with_capacity(num_rows, 0); - let string_iter = ArrayIter::new(string_array); - let delimiter_array_iter = ArrayIter::new(delimiter_array); - let count_array_iter = ArrayIter::new(count_array); - string_iter - .zip(delimiter_array_iter) - .zip(count_array_iter) - .for_each(|((string, delimiter), n)| match (string, delimiter, n) { + for ((string, delimiter), n) in string_array + .iter() + .zip(delimiter_array.iter()) + .zip(count_array.iter()) + { + match (string, delimiter, n) { (Some(string), Some(delimiter), Some(n)) => { - // In MySQL, these cases will return an empty string. - if n == 0 || string.is_empty() || delimiter.is_empty() { - builder.append_value(""); - return; + builder.append_value(substr_index_slice(string, delimiter, n)); + } + _ => builder.append_null(), + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +fn substr_index_view( + string_array: &StringViewArray, + delimiter_array: &StringViewArray, + count_array: &PrimitiveArray, +) -> Result { + let nulls = NullBuffer::union( + NullBuffer::union(string_array.nulls(), delimiter_array.nulls()).as_ref(), + count_array.nulls(), + ); + let views = string_array.views(); + let mut views_buf = Vec::with_capacity(string_array.len()); + let mut has_out_of_line = false; + + for (i, raw_view) in views.iter().enumerate() { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + views_buf.push(0); + continue; + } + + let string = string_array.value(i); + let delimiter = delimiter_array.value(i); + let count = count_array.value(i); + let substr = substr_index_slice(string, delimiter, count); + has_out_of_line |= append_substr_view(&mut views_buf, raw_view, string, substr); + } + + let data_buffers = if has_out_of_line { + string_array.data_buffers().to_vec() + } else { + vec![] + }; + + // Safety: each appended view is either: + // (1) a copied null sentinel, + // (2) the original valid input view, or + // (3) built by `append_view` for a contiguous substring of the input row. + unsafe { + Ok(Arc::new(StringViewArray::new_unchecked( + ScalarBuffer::from(views_buf), + data_buffers, + nulls, + )) as ArrayRef) + } +} + +fn substr_index_scalar_impl<'a, S, B>( + string_array: S, + delimiter: &str, + count: i64, + builder: B, +) -> Result +where + S: StringArrayType<'a> + Copy, + B: StringLikeArrayBuilder, +{ + if count == 0 || delimiter.is_empty() { + return map_strings(string_array, builder, |string| &string[..0]); + } + + if delimiter.len() == 1 { + let delimiter_byte = delimiter.as_bytes()[0]; + return map_strings(string_array, builder, |string| { + substr_index_single_byte(string, delimiter_byte, count) + }); + } + + let occurrence_idx = usize::try_from(count.unsigned_abs()).unwrap_or(usize::MAX) - 1; + if count > 0 { + let finder = memmem::Finder::new(delimiter.as_bytes()); + map_strings(string_array, builder, |string| { + substr_index_slice_finder(string, &finder, delimiter.len(), occurrence_idx) + }) + } else { + let finder_rev = memmem::FinderRev::new(delimiter.as_bytes()); + map_strings(string_array, builder, |string| { + substr_index_rslice_finder( + string, + &finder_rev, + delimiter.len(), + occurrence_idx, + ) + }) + } +} + +fn substr_index_scalar_view( + string_array: &StringViewArray, + delimiter: &str, + count: i64, +) -> Result { + let views = string_array.views(); + let mut views_buf = Vec::with_capacity(string_array.len()); + let mut has_out_of_line = false; + + if count == 0 || delimiter.is_empty() { + let empty_view = make_view(b"", 0, 0); + for i in 0..string_array.len() { + if string_array.is_null(i) { + views_buf.push(0); + } else { + views_buf.push(empty_view); + } + } + } else if delimiter.len() == 1 { + let delimiter_byte = delimiter.as_bytes()[0]; + for (i, raw_view) in views.iter().enumerate() { + if string_array.is_null(i) { + views_buf.push(0); + continue; + } + + let string = string_array.value(i); + let substr = substr_index_single_byte(string, delimiter_byte, count); + has_out_of_line |= + append_substr_view(&mut views_buf, raw_view, string, substr); + } + } else { + let occurrence_idx = + usize::try_from(count.unsigned_abs()).unwrap_or(usize::MAX) - 1; + if count > 0 { + let finder = memmem::Finder::new(delimiter.as_bytes()); + for (i, raw_view) in views.iter().enumerate() { + if string_array.is_null(i) { + views_buf.push(0); + continue; } - let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); - let result_idx = if delimiter.len() == 1 { - // Fast path: use byte-level search for single-character delimiters - let d_byte = delimiter.as_bytes()[0]; - let bytes = string.as_bytes(); - - if n > 0 { - bytes - .iter() - .enumerate() - .filter(|&(_, &b)| b == d_byte) - .nth(occurrences - 1) - .map(|(idx, _)| idx) - } else { - bytes - .iter() - .enumerate() - .rev() - .filter(|&(_, &b)| b == d_byte) - .nth(occurrences - 1) - .map(|(idx, _)| idx + 1) - } - } else if n > 0 { - // Multi-byte path: forward search for n-th occurrence - string - .match_indices(delimiter) - .nth(occurrences - 1) - .map(|(idx, _)| idx) - } else { - // Multi-byte path: backward search for n-th occurrence from the right - string - .rmatch_indices(delimiter) - .nth(occurrences - 1) - .map(|(idx, _)| idx + delimiter.len()) - }; - match result_idx { - Some(idx) => { - if n > 0 { - builder.append_value(&string[..idx]); - } else { - builder.append_value(&string[idx..]); - } - } - None => builder.append_value(string), + let string = string_array.value(i); + let substr = substr_index_slice_finder( + string, + &finder, + delimiter.len(), + occurrence_idx, + ); + has_out_of_line |= + append_substr_view(&mut views_buf, raw_view, string, substr); + } + } else { + let finder_rev = memmem::FinderRev::new(delimiter.as_bytes()); + for (i, raw_view) in views.iter().enumerate() { + if string_array.is_null(i) { + views_buf.push(0); + continue; } + + let string = string_array.value(i); + let substr = substr_index_rslice_finder( + string, + &finder_rev, + delimiter.len(), + occurrence_idx, + ); + has_out_of_line |= + append_substr_view(&mut views_buf, raw_view, string, substr); } - _ => builder.append_null(), - }); + } + } + + let data_buffers = if has_out_of_line { + string_array.data_buffers().to_vec() + } else { + vec![] + }; + // Safety: each appended view is either: + // (1) a copied null sentinel, + // (2) the original valid input view, + // (3) an inline empty string view, or + // (4) built by `append_view` for a contiguous substring of the input row. + unsafe { + Ok(Arc::new(StringViewArray::new_unchecked( + ScalarBuffer::from(views_buf), + data_buffers, + string_array.nulls().cloned(), + )) as ArrayRef) + } +} + +fn map_strings<'a, S, B, F>(string_array: S, mut builder: B, f: F) -> Result +where + S: StringArrayType<'a> + Copy, + B: StringLikeArrayBuilder, + F: Fn(&'a str) -> &'a str, +{ + for string in string_array.iter() { + match string { + Some(s) => builder.append_value(f(s)), + None => builder.append_null(), + } + } Ok(Arc::new(builder.finish()) as ArrayRef) } +#[inline] +fn substr_index_slice<'a>(string: &'a str, delimiter: &str, count: i64) -> &'a str { + if count == 0 || string.is_empty() || delimiter.is_empty() { + return &string[..0]; + } + + if delimiter.len() == 1 { + return substr_index_single_byte(string, delimiter.as_bytes()[0], count); + } + + let occurrences = usize::try_from(count.unsigned_abs()).unwrap_or(usize::MAX); + if count > 0 { + string + .match_indices(delimiter) + .nth(occurrences - 1) + .map(|(idx, _)| &string[..idx]) + .unwrap_or(string) + } else { + string + .rmatch_indices(delimiter) + .nth(occurrences - 1) + .map(|(idx, _)| &string[idx + delimiter.len()..]) + .unwrap_or(string) + } +} + +#[inline] +fn substr_index_single_byte(string: &str, delimiter: u8, count: i64) -> &str { + let occurrences = usize::try_from(count.unsigned_abs()).unwrap_or(usize::MAX); + let idx = if count > 0 { + memchr_iter(delimiter, string.as_bytes()).nth(occurrences - 1) + } else { + memrchr_iter(delimiter, string.as_bytes()) + .nth(occurrences - 1) + .map(|idx| idx + 1) + }; + + match idx { + Some(idx) if count > 0 => &string[..idx], + Some(idx) => &string[idx..], + None => string, + } +} + +#[inline] +fn substr_index_slice_finder<'a>( + string: &'a str, + finder: &memmem::Finder, + delimiter_len: usize, + occurrence_idx: usize, +) -> &'a str { + let bytes = string.as_bytes(); + let mut start = 0; + for _ in 0..occurrence_idx { + match finder.find(&bytes[start..]) { + Some(pos) => start += pos + delimiter_len, + None => return string, + } + } + + match finder.find(&bytes[start..]) { + Some(pos) => &string[..start + pos], + None => string, + } +} + +#[inline] +fn substr_index_rslice_finder<'a>( + string: &'a str, + finder: &memmem::FinderRev, + delimiter_len: usize, + occurrence_idx: usize, +) -> &'a str { + let bytes = string.as_bytes(); + let mut end = bytes.len(); + for _ in 0..occurrence_idx { + match finder.rfind(&bytes[..end]) { + Some(pos) => end = pos, + None => return string, + } + } + + match finder.rfind(&bytes[..end]) { + Some(pos) => &string[pos + delimiter_len..], + None => string, + } +} + +#[inline] +fn substr_view(original_view: &u128, substr: &str, start_offset: u32) -> u128 { + if substr.len() > 12 { + let view = ByteView::from(*original_view); + make_view( + substr.as_bytes(), + view.buffer_index, + view.offset + start_offset, + ) + } else { + make_view(substr.as_bytes(), 0, 0) + } +} + +#[inline] +fn append_substr_view( + views_buf: &mut Vec, + raw_view: &u128, + string: &str, + substr: &str, +) -> bool { + if substr.len() == string.len() { + views_buf.push(*raw_view); + return substr.len() > 12; + } + + if substr.is_empty() { + views_buf.push(make_view(b"", 0, 0)); + return false; + } + + let start_offset = substr.as_ptr() as usize - string.as_ptr() as usize; + let start_offset = + u32::try_from(start_offset).expect("string view offsets fit in u32"); + views_buf.push(substr_view(raw_view, substr, start_offset)); + substr.len() > 12 +} + #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{ + Array, ArrayRef, AsArray, Int64Array, StringArray, StringViewArray, + }; + use arrow::datatypes::DataType::{Utf8, Utf8View}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::config::ConfigOptions; use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; use crate::unicode::substrindex::SubstrIndexFunc; use crate::utils::test::test_function; @@ -340,6 +708,136 @@ mod tests { Utf8, StringArray ); + test_function!( + SubstrIndexFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "verylongprefix.segment.tail".into(), + ))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(".".into()))), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("verylongprefix")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + SubstrIndexFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "www.apache.org".into(), + ))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(".".into()))), + ColumnarValue::Scalar(ScalarValue::from(-1i64)), + ], + Ok(Some("org")), + &str, + Utf8View, + StringViewArray + ); + Ok(()) + } + + #[test] + fn test_substr_index_utf8view_scalar_fast_path() -> Result<()> { + let input = Arc::new(StringViewArray::from(vec![ + Some("alpha.beta.gamma"), + Some("short.val"), + None, + ])) as ArrayRef; + + let arg_fields = vec![ + Field::new("a", Utf8View, true).into(), + Field::new("b", Utf8View, true).into(), + Field::new("c", DataType::Int64, true).into(), + ]; + + let args = ScalarFunctionArgs { + number_rows: input.len(), + args: vec![ + ColumnarValue::Array(Arc::clone(&input)), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(".".into()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ], + arg_fields, + return_field: Field::new("f", Utf8View, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = match SubstrIndexFunc::new().invoke_with_args(args)? { + ColumnarValue::Array(result) => result, + other => panic!("expected array result, got {other:?}"), + }; + let result = result.as_string_view(); + + assert_eq!(result.len(), 3); + assert_eq!(result.value(0), "alpha"); + assert_eq!(result.value(1), "short"); + assert!(result.is_null(2)); + + Ok(()) + } + + #[test] + fn test_substr_index_utf8view_array_sliced() -> Result<()> { + use super::substr_index_view; + + let strings: StringViewArray = vec![ + Some("skip_this.value"), + Some("this_is_a_long_prefix.suffix"), + Some("short.val"), + Some("another_long_result.rest"), + None, + ] + .into_iter() + .collect(); + let delimiters: StringViewArray = + vec![Some("."), Some("."), Some("."), Some("."), Some(".")] + .into_iter() + .collect(); + let counts = Int64Array::from(vec![1, 1, -1, 1, 1]); + + let sliced_strings = strings.slice(1, 4); + let sliced_delimiters = delimiters.slice(1, 4); + let sliced_counts = counts.slice(1, 4); + + let result = + substr_index_view(&sliced_strings, &sliced_delimiters, &sliced_counts)?; + let result = result.as_string_view(); + + assert_eq!(result.len(), 4); + assert_eq!(result.value(0), "this_is_a_long_prefix"); + assert_eq!(result.value(1), "val"); + assert_eq!(result.value(2), "another_long_result"); + assert!(result.is_null(3)); + + Ok(()) + } + + #[test] + fn test_substr_index_utf8view_scalar_reuses_original_view_when_unchanged() + -> Result<()> { + use super::substr_index_scalar_view; + + let strings: StringViewArray = vec![ + Some("very_long_value_without_separator"), + Some("short"), + None, + ] + .into_iter() + .collect(); + + let result = substr_index_scalar_view(&strings, ".", 1)?; + let result = result.as_string_view(); + + assert_eq!(result.len(), 3); + assert_eq!(result.value(0), "very_long_value_without_separator"); + assert_eq!(result.value(1), "short"); + assert_eq!(result.views()[0], strings.views()[0]); + assert_eq!(result.views()[1], strings.views()[1]); + assert!(result.is_null(2)); + Ok(()) } } diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 126c4bcafb533..4de2d20b02e8a 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -1078,6 +1078,28 @@ logical_plan 01)Projection: substr_index(test.column1_utf8view, Utf8View("a"), Int64(1)) AS c, substr_index(test.column1_utf8view, Utf8View("a"), Int64(2)) AS c2 02)--TableScan: test projection=[column1_utf8view] +## Verify Utf8View input produces Utf8View output for SUBSTR_INDEX +query T +SELECT arrow_typeof(substr_index(arrow_cast('a.b.c', 'Utf8View'), '.', 1)); +---- +Utf8View + +## Verify array path also returns Utf8View for SUBSTR_INDEX +query T +SELECT arrow_typeof(substr_index(column1_utf8view, 'a', 1)) FROM test LIMIT 1; +---- +Utf8View + +## Verify array path values for SUBSTR_INDEX with Utf8View input +query T +SELECT substr_index(column1_utf8view, 'a', 1) FROM test; +---- +Andrew +Xi +R +(empty) +NULL + ## Ensure no casts on columns for STARTS_WITH query TT