Skip to content

Commit 8ea71fd

Browse files
committed
revised PR
1 parent 3e60f15 commit 8ea71fd

File tree

4 files changed

+342
-399
lines changed

4 files changed

+342
-399
lines changed

datafusion/physical-plan/src/aggregates/group_values/mod.rs

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ use crate::aggregates::{
4646
},
4747
order::GroupOrdering,
4848
};
49-
use arrow::array::*;
50-
use std::sync::Arc;
5149

5250
mod metrics;
5351
mod null_builder;
@@ -140,9 +138,6 @@ pub fn new_group_values(
140138
) -> Result<Box<dyn GroupValues>> {
141139
if schema.fields.len() == 1 {
142140
let d = schema.fields[0].data_type();
143-
println!(
144-
"[should be dictionary encoded] single column group by with data type: {d:#?}"
145-
);
146141

147142
macro_rules! downcast_helper {
148143
($t:ty, $d:ident) => {
@@ -204,18 +199,50 @@ pub fn new_group_values(
204199
}
205200
DataType::Dictionary(key_type, value_type) => {
206201
if supported_single_dictionary_value(value_type) {
207-
println!("dictionary type detected, using GroupValuesDictionary");
208-
return match key_type.as_ref() { // TODO: turn this into a macro
209-
DataType::Int8 => Ok(Box::new(GroupValuesDictionary::<arrow::datatypes::Int8Type>::new(value_type))),
210-
DataType::Int16 => Ok(Box::new(GroupValuesDictionary::<arrow::datatypes::Int16Type>::new(value_type))),
211-
DataType::Int32 => Ok(Box::new(GroupValuesDictionary::<arrow::datatypes::Int32Type>::new(value_type))),
212-
DataType::Int64 => Ok(Box::new(GroupValuesDictionary::<arrow::datatypes::Int64Type>::new(value_type))),
213-
DataType::UInt8 => Ok(Box::new(GroupValuesDictionary::<arrow::datatypes::UInt8Type>::new(value_type))),
214-
DataType::UInt16 => Ok(Box::new(GroupValuesDictionary::<arrow::datatypes::UInt16Type>::new(value_type))),
215-
DataType::UInt32 => Ok(Box::new(GroupValuesDictionary::<arrow::datatypes::UInt32Type>::new(value_type))),
216-
DataType::UInt64 => Ok(Box::new(GroupValuesDictionary::<arrow::datatypes::UInt64Type>::new(value_type))),
202+
return match key_type.as_ref() {
203+
// TODO: turn this into a macro
204+
DataType::Int8 => {
205+
Ok(Box::new(GroupValuesDictionary::<
206+
arrow::datatypes::Int8Type,
207+
>::new(value_type)))
208+
}
209+
DataType::Int16 => {
210+
Ok(Box::new(GroupValuesDictionary::<
211+
arrow::datatypes::Int16Type,
212+
>::new(value_type)))
213+
}
214+
DataType::Int32 => {
215+
Ok(Box::new(GroupValuesDictionary::<
216+
arrow::datatypes::Int32Type,
217+
>::new(value_type)))
218+
}
219+
DataType::Int64 => {
220+
Ok(Box::new(GroupValuesDictionary::<
221+
arrow::datatypes::Int64Type,
222+
>::new(value_type)))
223+
}
224+
DataType::UInt8 => {
225+
Ok(Box::new(GroupValuesDictionary::<
226+
arrow::datatypes::UInt8Type,
227+
>::new(value_type)))
228+
}
229+
DataType::UInt16 => {
230+
Ok(Box::new(GroupValuesDictionary::<
231+
arrow::datatypes::UInt16Type,
232+
>::new(value_type)))
233+
}
234+
DataType::UInt32 => {
235+
Ok(Box::new(GroupValuesDictionary::<
236+
arrow::datatypes::UInt32Type,
237+
>::new(value_type)))
238+
}
239+
DataType::UInt64 => {
240+
Ok(Box::new(GroupValuesDictionary::<
241+
arrow::datatypes::UInt64Type,
242+
>::new(value_type)))
243+
}
217244
_ => Err(datafusion_common::DataFusionError::NotImplemented(
218-
format!("Unsupported dictionary key type: {:?}", key_type)
245+
format!("Unsupported dictionary key type: {:?}", key_type),
219246
)),
220247
};
221248
}
@@ -255,4 +282,3 @@ fn supported_single_dictionary_value(t: &DataType) -> bool {
255282
| DataType::UInt64
256283
)
257284
}
258-

datafusion/physical-plan/src/aggregates/group_values/row.rs

Lines changed: 0 additions & 215 deletions
Original file line numberDiff line numberDiff line change
@@ -324,218 +324,3 @@ fn dictionary_encode_if_necessary(
324324
(_, _) => Ok(Arc::<dyn Array>::clone(array)),
325325
}
326326
}
327-
328-
mod playground {
329-
use std::ops::Index;
330-
331-
use arrow::array::{AsArray, Datum};
332-
use datafusion_execution::TaskContext;
333-
334-
use crate::{ExecutionPlan, test::TestMemoryExec};
335-
336-
use super::*;
337-
use arrow::array::{Array, ArrayRef, StringArray};
338-
use std::string::String;
339-
struct TrivialGroupBy {
340-
seen_strings: Vec<String>,
341-
cur_size: usize,
342-
}
343-
impl GroupValues for TrivialGroupBy {
344-
// trivial approach: assumes there is only one column
345-
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
346-
let n_rows = cols[0].len();
347-
let column_count = cols.len();
348-
// iterate over all the rows; for each row, concatenate each value into a final string
349-
for row_idx in 0..n_rows {
350-
// grab the underlying value; assume it's a string
351-
let mut cur_str = String::from("");
352-
for col_idx in 0..column_count {
353-
let array = cols.get(col_idx).unwrap();
354-
let string_array =
355-
array.as_any().downcast_ref::<StringArray>().unwrap();
356-
if string_array.is_valid(row_idx) {
357-
let value = string_array.value(row_idx);
358-
cur_str.push_str(value);
359-
}
360-
}
361-
let idx = if let Some(i) =
362-
self.seen_strings.iter().position(|x| x == &cur_str)
363-
{
364-
i
365-
} else {
366-
self.seen_strings.push(cur_str);
367-
self.seen_strings.len() - 1
368-
};
369-
groups.push(idx);
370-
}
371-
println!("{:?}", self.seen_strings);
372-
Ok(())
373-
}
374-
fn len(&self) -> usize {
375-
self.seen_strings.len()
376-
}
377-
fn is_empty(&self) -> bool {
378-
self.seen_strings.is_empty()
379-
}
380-
fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
381-
let strings_to_emit = match emit_to {
382-
EmitTo::All => {
383-
// take all groups, clear internal state
384-
std::mem::take(&mut self.seen_strings)
385-
}
386-
EmitTo::First(n) => {
387-
// take only the first n groups
388-
// drain removes them from seen_strings
389-
self.seen_strings.drain(..n).collect()
390-
}
391-
};
392-
393-
// convert our Vec<String> back into an Arrow StringArray
394-
let array: ArrayRef = Arc::new(StringArray::from(
395-
strings_to_emit
396-
.iter()
397-
.map(|s| s.as_str())
398-
.collect::<Vec<_>>(),
399-
));
400-
401-
// return one array per GROUP BY column
402-
// since we're trivial and only support one column, just return one
403-
Ok(vec![array])
404-
}
405-
406-
fn size(&self) -> usize {
407-
self.cur_size
408-
}
409-
fn clear_shrink(&mut self, num_rows: usize) {
410-
let _ = num_rows;
411-
}
412-
}
413-
414-
#[test]
415-
fn test_trivial_group_by_single_column() {
416-
// Test grouping on a single string column
417-
let strings = vec!["apple", "banana", "apple", "cherry", "banana"];
418-
let array: ArrayRef = Arc::new(StringArray::from(strings));
419-
420-
let mut group_by = TrivialGroupBy {
421-
seen_strings: Vec::new(),
422-
cur_size: 0,
423-
};
424-
425-
// Intern the group keys
426-
let mut groups = Vec::new();
427-
group_by.intern(&[array], &mut groups).unwrap();
428-
429-
// Should have assigned group ids: 0, 1, 0, 2, 1
430-
assert_eq!(groups, vec![0, 1, 0, 2, 1]);
431-
assert_eq!(group_by.len(), 3); // apple, banana, cherry
432-
433-
// Emit all groups
434-
let emitted = group_by.emit(EmitTo::All).unwrap();
435-
assert_eq!(emitted.len(), 1); // One column
436-
let emitted_array = emitted[0].as_any().downcast_ref::<StringArray>().unwrap();
437-
assert_eq!(emitted_array.value(0), "apple");
438-
assert_eq!(emitted_array.value(1), "banana");
439-
assert_eq!(emitted_array.value(2), "cherry");
440-
}
441-
442-
#[test]
443-
fn test_trivial_group_by_two_columns() {
444-
// Test grouping on two string columns
445-
let col1 = vec!["a", "a", "b", "a", "b"];
446-
let col2 = vec!["x", "y", "x", "x", "y"];
447-
448-
let array1: ArrayRef = Arc::new(StringArray::from(col1));
449-
let array2: ArrayRef = Arc::new(StringArray::from(col2));
450-
451-
let mut group_by = TrivialGroupBy {
452-
seen_strings: Vec::new(),
453-
cur_size: 0,
454-
};
455-
456-
// Intern: concatenates ("a" + "x"), ("a" + "y"), ("b" + "x"), etc.
457-
let mut groups = Vec::new();
458-
group_by.intern(&[array1, array2], &mut groups).unwrap();
459-
460-
// Should have 4 distinct groups: "ax", "ay", "bx", "by"
461-
assert_eq!(group_by.len(), 4);
462-
assert_eq!(groups, vec![0, 1, 2, 0, 3]); // group ids assigned
463-
464-
// Emit all groups
465-
let emitted = group_by.emit(EmitTo::All).unwrap();
466-
assert_eq!(emitted.len(), 1); // One output column (concatenated strings)
467-
let emitted_array = emitted[0].as_any().downcast_ref::<StringArray>().unwrap();
468-
assert_eq!(emitted_array.value(0), "ax");
469-
assert_eq!(emitted_array.value(1), "ay");
470-
assert_eq!(emitted_array.value(2), "bx");
471-
assert_eq!(emitted_array.value(3), "by");
472-
}
473-
474-
#[tokio::test]
475-
async fn test_trivial_group_by_dictionary() -> Result<()> {
476-
use crate::aggregates::RecordBatch;
477-
use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
478-
use crate::common::collect;
479-
use crate::test::TestMemoryExec;
480-
use arrow::array::DictionaryArray;
481-
use arrow::datatypes::{DataType, Field, Schema};
482-
use datafusion_functions_aggregate::count::count_udaf;
483-
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
484-
use datafusion_physical_expr::expressions::col;
485-
486-
// Create schema with dictionary column and value column
487-
let schema = Arc::new(Schema::new(vec![
488-
Field::new(
489-
"color",
490-
DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
491-
false,
492-
),
493-
Field::new("amount", DataType::UInt32, false),
494-
]));
495-
496-
// Create dictionary array
497-
let values = StringArray::from(vec!["red", "blue", "green"]);
498-
let keys = arrow::array::UInt8Array::from(vec![0, 1, 0, 2, 1]);
499-
let dict_array: ArrayRef = Arc::new(DictionaryArray::<
500-
arrow::datatypes::UInt8Type,
501-
>::try_new(keys, Arc::new(values))?);
502-
503-
// Create value column
504-
let amount_array: ArrayRef =
505-
Arc::new(arrow::array::UInt32Array::from(vec![1, 2, 3, 4, 5]));
506-
507-
// Create batch
508-
let batch =
509-
RecordBatch::try_new(Arc::clone(&schema), vec![dict_array, amount_array])?;
510-
511-
// Create in-memory source with the batch
512-
let source =
513-
TestMemoryExec::try_new(&vec![vec![batch]], Arc::clone(&schema), None)?;
514-
515-
// Create GROUP BY expression
516-
let group_expr = vec![(col("color", &schema)?, "color".to_string())];
517-
518-
// Create COUNT(amount) aggregate expression
519-
let aggr_expr = vec![Arc::new(
520-
AggregateExprBuilder::new(count_udaf(), vec![col("amount", &schema)?])
521-
.schema(Arc::clone(&schema))
522-
.alias("count_amount")
523-
.build()?,
524-
)];
525-
526-
// Create AggregateExec
527-
let aggregate_exec = AggregateExec::try_new(
528-
AggregateMode::SinglePartitioned,
529-
PhysicalGroupBy::new_single(group_expr),
530-
aggr_expr,
531-
vec![None],
532-
Arc::new(source),
533-
Arc::clone(&schema),
534-
)?;
535-
536-
let output =
537-
collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
538-
println!("Output batch: {:#?}", output);
539-
Ok(())
540-
}
541-
}

0 commit comments

Comments
 (0)