Skip to content

Commit 7127c20

Browse files
committed
Fix MIN/MAX to preserve dictionary types in execution
Update `get_min_max_result_type` to preserve `Dictionary<K, V>` input types instead of unwrapping them to `V`, so that planned MIN/MAX execution uses the dictionary-aware accumulator. Add an end-to-end SQL regression test to ensure that MIN/MAX ignores unreferenced dictionary values and preserves the dictionary-typed output schema. Adjust the unit-test expectations for the dictionary coercion tests to reflect the new planned-path behavior.
1 parent 9240400 commit 7127c20

2 files changed

Lines changed: 63 additions & 15 deletions

File tree

datafusion/core/tests/sql/aggregates/basic.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,44 @@ async fn count_distinct_dictionary_mixed_values() -> Result<()> {
442442
Ok(())
443443
}
444444

445+
#[tokio::test]
446+
async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> {
447+
let ctx = SessionContext::new();
448+
449+
let dict_values = StringArray::from(vec!["a", "z", "zz_unused"]);
450+
let dict_indices = Int32Array::from(vec![Some(1), Some(1), None]);
451+
let dict = DictionaryArray::new(dict_indices, Arc::new(dict_values));
452+
453+
let dict_type =
454+
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
455+
let schema = Arc::new(Schema::new(vec![Field::new("dict", dict_type.clone(), true)]));
456+
457+
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(dict)])?;
458+
let provider = MemTable::try_new(schema, vec![vec![batch]])?;
459+
ctx.register_table("t", Arc::new(provider))?;
460+
461+
let df = ctx
462+
.sql("SELECT min(dict) AS min_dict, max(dict) AS max_dict FROM t")
463+
.await?;
464+
let results = df.collect().await?;
465+
466+
assert_eq!(results[0].schema().field(0).data_type(), &dict_type);
467+
assert_eq!(results[0].schema().field(1).data_type(), &dict_type);
468+
469+
assert_snapshot!(
470+
batches_to_string(&results),
471+
@r"
472+
+----------+----------+
473+
| min_dict | max_dict |
474+
+----------+----------+
475+
| z | z |
476+
+----------+----------+
477+
"
478+
);
479+
480+
Ok(())
481+
}
482+
445483
#[tokio::test]
446484
async fn group_by_ree_dict_column() -> Result<()> {
447485
let ctx = SessionContext::new();

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
5353
use datafusion_macros::user_doc;
5454
use half::f16;
5555
use std::mem::size_of_val;
56-
use std::ops::Deref;
5756

5857
fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
5958
// make sure that the input types only has one element.
@@ -63,17 +62,12 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
6362
input_types.len()
6463
);
6564
}
66-
// min and max support the dictionary data type
67-
// unpack the dictionary to get the value
68-
match &input_types[0] {
69-
DataType::Dictionary(_, dict_value_type) => {
70-
// TODO add checker, if the value type is complex data type
71-
Ok(vec![dict_value_type.deref().clone()])
72-
}
73-
// TODO add checker for datatype which min and max supported
74-
// For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
75-
_ => Ok(input_types.to_vec()),
76-
}
65+
// Preserve dictionary inputs so planned MIN/MAX execution uses the same
66+
// dictionary-aware accumulator/state path as direct accumulator tests.
67+
//
68+
// TODO add checker for datatype which min and max supported.
69+
// For example, the `Struct` and `Map` types are not supported by the MIN and MAX functions.
70+
Ok(input_types.to_vec())
7771
}
7872

7973
#[user_doc(
@@ -1223,6 +1217,10 @@ mod tests {
12231217
vec![DataType::Decimal128(10, 2)],
12241218
vec![DataType::Decimal256(1, 1)],
12251219
vec![DataType::Utf8],
1220+
vec![DataType::Dictionary(
1221+
Box::new(DataType::Int32),
1222+
Box::new(DataType::Utf8),
1223+
)],
12261224
];
12271225
for fun in funs {
12281226
for input_type in &input_types {
@@ -1237,7 +1235,13 @@ mod tests {
12371235
let data_type =
12381236
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
12391237
let result = get_min_max_result_type(&[data_type])?;
1240-
assert_eq!(result, vec![DataType::Utf8]);
1238+
assert_eq!(
1239+
result,
1240+
vec![DataType::Dictionary(
1241+
Box::new(DataType::Int32),
1242+
Box::new(DataType::Utf8),
1243+
)]
1244+
);
12411245
Ok(())
12421246
}
12431247

@@ -1254,12 +1258,18 @@ mod tests {
12541258
let mut min_acc = MinAccumulator::try_new(&rt_type)?;
12551259
min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
12561260
let min_result = min_acc.evaluate()?;
1257-
assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string())));
1261+
assert_eq!(
1262+
min_result,
1263+
dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string())))
1264+
);
12581265

12591266
let mut max_acc = MaxAccumulator::try_new(&rt_type)?;
12601267
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
12611268
let max_result = max_acc.evaluate()?;
1262-
assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string())));
1269+
assert_eq!(
1270+
max_result,
1271+
dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string())))
1272+
);
12631273
Ok(())
12641274
}
12651275

0 commit comments

Comments
 (0)