Skip to content

Commit 7bd29e1

Browse files
committed
feat: enhance dictionary comparison logic and add unit tests
- Improved comparison behavior between dictionaries:
  - Dictionaries with the same key type now compare their inner logical values and rewrap the result.
  - Dictionaries with different key types raise an explicit internal error.
  - Comparisons between a dictionary and a non-dictionary value now compare the inner logical values directly.
- Updated and tightened the macro comment to clarify the limits of mixed-type dictionary support.
- Added focused unit tests for:
  - Dictionary vs. scalar comparison
  - Same-key dictionary rewrapping
  - Mismatched dictionary key types
  - Incompatible dictionary and plain scalar comparisons
1 parent 150bc6f commit 7bd29e1

1 file changed

Lines changed: 128 additions & 33 deletions

File tree

  • datafusion/functions-aggregate-common/src

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 128 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ macro_rules! min_max_generic {
142142
}
143143

144144
// min/max of two logically compatible scalar values.
145-
// Dictionary scalars are unwrapped to their inner values for comparison,
146-
// then rewrapped with the dictionary key type when both inputs are dictionaries
147-
// after validating that their key types match.
145+
// Dictionary scalars participate by comparing their inner logical values.
146+
// When both inputs are dictionaries, matching key types are preserved in the
147+
// result; differing key types remain an unexpected invariant violation.
148148
macro_rules! min_max {
149149
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
150150
Ok(match ($VALUE, $DELTA) {
@@ -416,31 +416,38 @@ macro_rules! min_max {
416416
min_max_generic!(lhs, rhs, $OP)
417417
}
418418

419-
(lhs, rhs)
420-
if matches!(lhs, ScalarValue::Dictionary(_, _))
421-
|| matches!(rhs, ScalarValue::Dictionary(_, _)) =>
422-
{
423-
let (lhs, lhs_key_type) = dictionary_scalar_parts(lhs);
424-
let (rhs, rhs_key_type) = dictionary_scalar_parts(rhs);
425-
let result = min_max_generic!(lhs, rhs, $OP);
426-
427-
match lhs_key_type.zip(rhs_key_type) {
428-
Some((lhs_key_type, rhs_key_type)) => {
429-
if lhs_key_type != rhs_key_type {
430-
return internal_err!(
431-
"MIN/MAX is not expected to receive dictionary scalars with different key types ({:?} vs {:?})",
432-
lhs_key_type,
433-
rhs_key_type
434-
);
435-
}
436-
437-
ScalarValue::Dictionary(
438-
Box::new(lhs_key_type.clone()),
439-
Box::new(result),
440-
)
441-
}
442-
None => result,
419+
(
420+
ScalarValue::Dictionary(lhs_key_type, lhs),
421+
ScalarValue::Dictionary(rhs_key_type, rhs),
422+
) => {
423+
if lhs_key_type != rhs_key_type {
424+
return internal_err!(
425+
"MIN/MAX is not expected to receive dictionary scalars with different key types ({:?} vs {:?})",
426+
lhs_key_type,
427+
rhs_key_type
428+
);
443429
}
430+
431+
let result = dictionary_inner_scalar_min_max(
432+
lhs.as_ref(),
433+
rhs.as_ref(),
434+
choose_min_max!($OP),
435+
)?;
436+
ScalarValue::Dictionary(lhs_key_type.clone(), Box::new(result))
437+
}
438+
(ScalarValue::Dictionary(_, lhs), rhs) => {
439+
dictionary_inner_scalar_min_max(
440+
lhs.as_ref(),
441+
rhs,
442+
choose_min_max!($OP),
443+
)?
444+
}
445+
(lhs, ScalarValue::Dictionary(_, rhs)) => {
446+
dictionary_inner_scalar_min_max(
447+
lhs,
448+
rhs.as_ref(),
449+
choose_min_max!($OP),
450+
)?
444451
}
445452

446453
e => {
@@ -485,12 +492,15 @@ fn min_max_batch_generic(values: &ArrayRef, ordering: Ordering) -> Result<Scalar
485492
Ok(extreme)
486493
}
487494

488-
fn dictionary_scalar_parts(value: &ScalarValue) -> (&ScalarValue, Option<&DataType>) {
489-
match value {
490-
ScalarValue::Dictionary(key_type, inner) => {
491-
(inner.as_ref(), Some(key_type.as_ref()))
492-
}
493-
other => (other, None),
495+
fn dictionary_inner_scalar_min_max(
496+
lhs: &ScalarValue,
497+
rhs: &ScalarValue,
498+
ordering: Ordering,
499+
) -> Result<ScalarValue> {
500+
match ordering {
501+
Ordering::Greater => min_max!(lhs, rhs, min),
502+
Ordering::Less => min_max!(lhs, rhs, max),
503+
Ordering::Equal => unreachable!("min/max comparisons do not use equal ordering"),
494504
}
495505
}
496506

@@ -893,3 +903,88 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
893903
_ => min_max_batch!(values, max),
894904
})
895905
}
906+
907+
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates `min_max!(lhs, rhs, max)` inside its own function.
    ///
    /// The macro expansion contains `return internal_err!(..)` and `?`, both of
    /// which exit the *enclosing* function. Invoking the macro directly in a
    /// test that returns `Result<()>` would make the test return `Err` before
    /// any `unwrap_err()` assertion could run. Going through this helper turns
    /// the error into an ordinary value the test can inspect.
    fn max_scalar(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
        min_max!(lhs, rhs, max)
    }

    /// Builds a dictionary scalar with the given key type wrapping a `Float32` value.
    fn float32_dictionary(key_type: DataType, value: f32) -> ScalarValue {
        ScalarValue::Dictionary(
            Box::new(key_type),
            Box::new(ScalarValue::Float32(Some(value))),
        )
    }

    #[test]
    fn min_max_dictionary_and_scalar_compare_by_inner_value() -> Result<()> {
        let dictionary = float32_dictionary(DataType::Int32, 1.0);
        let scalar = ScalarValue::Float32(Some(2.0));

        let result = max_scalar(&dictionary, &scalar)?;

        // Mixed dictionary/plain input yields the plain (unwrapped) value.
        assert_eq!(result, ScalarValue::Float32(Some(2.0)));
        Ok(())
    }

    #[test]
    fn min_max_dictionary_same_key_type_rewraps_result() -> Result<()> {
        let lhs = float32_dictionary(DataType::Int32, 1.0);
        let rhs = float32_dictionary(DataType::Int32, 2.0);

        let result = max_scalar(&lhs, &rhs)?;

        // Both inputs are dictionaries with the same key type, so the result
        // keeps the dictionary wrapper.
        assert_eq!(result, float32_dictionary(DataType::Int32, 2.0));
        Ok(())
    }

    #[test]
    fn min_max_dictionary_different_key_types_error() {
        let lhs = float32_dictionary(DataType::Int8, 1.0);
        let rhs = float32_dictionary(DataType::Int32, 2.0);

        let error = max_scalar(&lhs, &rhs).unwrap_err();

        assert!(
            error
                .to_string()
                .contains("dictionary scalars with different key types"),
            "unexpected error: {error}"
        );
    }

    #[test]
    fn min_max_dictionary_and_incompatible_scalar_error() {
        let dictionary = float32_dictionary(DataType::Int32, 1.0);
        let scalar = ScalarValue::Int32(Some(2));

        let error = max_scalar(&dictionary, &scalar).unwrap_err();

        assert!(
            error
                .to_string()
                .contains("logically incompatible scalar values"),
            "unexpected error: {error}"
        );
    }
}

0 commit comments

Comments
 (0)