Skip to content

Commit bd8f1ad

Browse files
committed
feat(aggregate): simplify min/max helper and enhance testing for Dictionary(Int8, Utf8)
- Removed redundant `dictionary_inner_scalar_min_max` helper and invoked `min_max_scalar(...)` directly in the dictionary match arms to streamline code. - Added end-to-end aggregate test for `Dictionary(Int8, Utf8)` via `test_min_max_dictionary_int8_keys`. - Introduced a generic test helper for building string dictionaries with various key types, reducing setup duplication.
1 parent 47f75b2 commit bd8f1ad

2 files changed

Lines changed: 32 additions & 30 deletions

File tree

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -428,26 +428,15 @@ macro_rules! min_max_scalar_impl {
428428
);
429429
}
430430

431-
let result = dictionary_inner_scalar_min_max(
432-
lhs.as_ref(),
433-
rhs.as_ref(),
434-
choose_min_max!($OP),
435-
)?;
431+
let result =
432+
min_max_scalar(lhs.as_ref(), rhs.as_ref(), choose_min_max!($OP))?;
436433
ScalarValue::Dictionary(lhs_key_type.clone(), Box::new(result))
437434
}
438435
(ScalarValue::Dictionary(_, lhs), rhs) => {
439-
dictionary_inner_scalar_min_max(
440-
lhs.as_ref(),
441-
rhs,
442-
choose_min_max!($OP),
443-
)?
436+
min_max_scalar(lhs.as_ref(), rhs, choose_min_max!($OP))?
444437
}
445438
(lhs, ScalarValue::Dictionary(_, rhs)) => {
446-
dictionary_inner_scalar_min_max(
447-
lhs,
448-
rhs.as_ref(),
449-
choose_min_max!($OP),
450-
)?
439+
min_max_scalar(lhs, rhs.as_ref(), choose_min_max!($OP))?
451440
}
452441

453442
e => {
@@ -510,14 +499,6 @@ fn min_max_batch_generic(values: &ArrayRef, ordering: Ordering) -> Result<Scalar
510499
Ok(extreme)
511500
}
512501

513-
fn dictionary_inner_scalar_min_max(
514-
lhs: &ScalarValue,
515-
rhs: &ScalarValue,
516-
ordering: Ordering,
517-
) -> Result<ScalarValue> {
518-
min_max_scalar(lhs, rhs, ordering)
519-
}
520-
521502
/// An accumulator to compute the maximum value
522503
#[derive(Debug, Clone)]
523504
pub struct MaxAccumulator {

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,12 +1004,13 @@ mod tests {
10041004
use super::*;
10051005
use arrow::{
10061006
array::{
1007-
DictionaryArray, Float32Array, Int32Array, IntervalDayTimeArray,
1008-
IntervalMonthDayNanoArray, IntervalYearMonthArray, StringArray,
1007+
Array, DictionaryArray, Float32Array, Int32Array, Int8Array,
1008+
IntervalDayTimeArray, IntervalMonthDayNanoArray, PrimitiveArray,
1009+
IntervalYearMonthArray, StringArray,
10091010
},
10101011
datatypes::{
1011-
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
1012-
IntervalYearMonthType,
1012+
ArrowDictionaryKeyType, IntervalDayTimeType, IntervalMonthDayNanoType,
1013+
IntervalUnit, IntervalYearMonthType,
10131014
},
10141015
};
10151016
use std::sync::Arc;
@@ -1272,10 +1273,18 @@ mod tests {
12721273
}
12731274

12741275
fn string_dictionary_batch(values: &[&str], keys: &[Option<i32>]) -> ArrayRef {
1276+
string_dictionary_batch_with_keys(Int32Array::from(keys.to_vec()), values)
1277+
}
1278+
1279+
fn string_dictionary_batch_with_keys<K>(
1280+
keys: PrimitiveArray<K>,
1281+
values: &[&str],
1282+
) -> ArrayRef
1283+
where
1284+
K: ArrowDictionaryKeyType,
1285+
{
12751286
let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef;
1276-
Arc::new(
1277-
DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(),
1278-
) as ArrayRef
1287+
Arc::new(DictionaryArray::try_new(keys, values).unwrap()) as ArrayRef
12791288
}
12801289

12811290
fn optional_string_dictionary_batch(
@@ -1383,6 +1392,18 @@ mod tests {
13831392
assert_dictionary_min_max(&dict_type, &[batch1, batch2], "a", "d")
13841393
}
13851394

1395+
#[test]
1396+
fn test_min_max_dictionary_int8_keys() -> Result<()> {
1397+
let dict_type =
1398+
DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
1399+
let dict_array_ref = string_dictionary_batch_with_keys(
1400+
Int8Array::from(vec![Some(0), Some(1), Some(2), Some(3)]),
1401+
&["b", "c", "a", "d"],
1402+
);
1403+
1404+
assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d")
1405+
}
1406+
13861407
#[test]
13871408
fn test_min_max_dictionary_float_with_nans() -> Result<()> {
13881409
let dict_type =

0 commit comments

Comments
 (0)