Skip to content

Commit aaf9ca7

Browse files
committed
Researching code flow
1 parent 37cd3de commit aaf9ca7

5 files changed

Lines changed: 815 additions & 2 deletions

File tree

datafusion/physical-plan/src/aggregates/group_values/mod.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ pub(crate) use single_group_by::primitive::HashValue;
4141
use crate::aggregates::{
4242
group_values::single_group_by::{
4343
boolean::GroupValuesBoolean, bytes::GroupValuesBytes,
44-
bytes_view::GroupValuesBytesView, primitive::GroupValuesPrimitive,
44+
bytes_view::GroupValuesBytesView, primitive::GroupValuesPrimitive, dictionary::GroupValuesDictionary,
4545
},
4646
order::GroupOrdering,
4747
};
@@ -137,6 +137,9 @@ pub fn new_group_values(
137137
) -> Result<Box<dyn GroupValues>> {
138138
if schema.fields.len() == 1 {
139139
let d = schema.fields[0].data_type();
140+
println!(
141+
"[should be dictionary encoded] single column group by with data type: {d:#?}"
142+
);
140143

141144
macro_rules! downcast_helper {
142145
($t:ty, $d:ident) => {
@@ -196,6 +199,11 @@ pub fn new_group_values(
196199
DataType::Boolean => {
197200
return Ok(Box::new(GroupValuesBoolean::new()));
198201
}
202+
/*DataType::Dictionary(_, _) => {
203+
println!("dictionary type detected, using SingleDictionaryGroupValues");
204+
return Ok(Box::new(SingleDictionaryGroupValues::new()));
205+
206+
}*/
199207
_ => {}
200208
}
201209
}
@@ -207,6 +215,7 @@ pub fn new_group_values(
207215
Ok(Box::new(GroupValuesColumn::<true>::try_new(schema)?))
208216
}
209217
} else {
218+
// TODO: add specialized implementation for dictionary encoding columns for 2+ group by columns case
210219
Ok(Box::new(GroupValuesRows::try_new(schema)?))
211220
}
212221
}

datafusion/physical-plan/src/aggregates/group_values/row.rs

Lines changed: 222 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ use arrow::array::{
2323
use arrow::compute::cast;
2424
use arrow::datatypes::{DataType, SchemaRef};
2525
use arrow::row::{RowConverter, Rows, SortField};
26-
use datafusion_common::Result;
2726
use datafusion_common::hash_utils::RandomState;
2827
use datafusion_common::hash_utils::create_hashes;
28+
use datafusion_common::{Result, internal_err};
2929
use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt};
3030
use datafusion_expr::EmitTo;
3131
use hashbrown::hash_table::HashTable;
@@ -324,3 +324,224 @@ fn dictionary_encode_if_necessary(
324324
(_, _) => Ok(Arc::<dyn Array>::clone(array)),
325325
}
326326
}
327+
328+
mod playground {
329+
use std::ops::Index;
330+
331+
use arrow::array::{AsArray, Datum};
332+
use datafusion_execution::TaskContext;
333+
334+
use crate::{ExecutionPlan, test::TestMemoryExec};
335+
336+
use super::*;
337+
use arrow::array::{Array, ArrayRef, StringArray};
338+
use std::string::String;
339+
struct TrivialGroupBy {
340+
seen_strings: Vec<String>,
341+
cur_size: usize,
342+
}
343+
impl GroupValues for TrivialGroupBy {
344+
// trivial apprach assume theres only one columns
345+
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
346+
let n_rows = cols[0].len();
347+
let column_count = cols.len();
348+
// iterate all the rows, for each row, concate each value into a final string
349+
for row_idx in 0..n_rows {
350+
// grab the underlying value; assume its a string
351+
let mut cur_str = String::from("");
352+
for col_idx in 0..column_count {
353+
let array = cols.get(col_idx).unwrap();
354+
let string_array =
355+
array.as_any().downcast_ref::<StringArray>().unwrap();
356+
if string_array.is_valid(row_idx) {
357+
let value = string_array.value(row_idx);
358+
cur_str.push_str(value);
359+
}
360+
}
361+
let idx = if let Some(i) =
362+
self.seen_strings.iter().position(|x| x == &cur_str)
363+
{
364+
i
365+
} else {
366+
self.seen_strings.push(cur_str);
367+
self.seen_strings.len() - 1
368+
};
369+
groups.push(idx);
370+
}
371+
println!("{:?}", self.seen_strings);
372+
Ok(())
373+
}
374+
fn len(&self) -> usize {
375+
self.seen_strings.len()
376+
}
377+
fn is_empty(&self) -> bool {
378+
self.seen_strings.is_empty()
379+
}
380+
fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
381+
let strings_to_emit = match emit_to {
382+
EmitTo::All => {
383+
// take all groups, clear internal state
384+
std::mem::take(&mut self.seen_strings)
385+
}
386+
EmitTo::First(n) => {
387+
// take only the first n groups
388+
// drain removes them from seen_strings
389+
self.seen_strings.drain(..n).collect()
390+
}
391+
};
392+
393+
// convert our Vec<String> back into an Arrow StringArray
394+
let array: ArrayRef = Arc::new(StringArray::from(
395+
strings_to_emit
396+
.iter()
397+
.map(|s| s.as_str())
398+
.collect::<Vec<_>>(),
399+
));
400+
401+
// return one array per GROUP BY column
402+
// since we're trivial and only support one column, just return one
403+
Ok(vec![array])
404+
}
405+
406+
fn size(&self) -> usize {
407+
self.cur_size
408+
}
409+
fn clear_shrink(&mut self, num_rows: usize) {
410+
let _ = num_rows;
411+
}
412+
}
413+
414+
#[test]
415+
fn test_trivial_group_by_single_column() {
416+
// Test grouping on a single string column
417+
let strings = vec!["apple", "banana", "apple", "cherry", "banana"];
418+
let array: ArrayRef = Arc::new(StringArray::from(strings));
419+
420+
let mut group_by = TrivialGroupBy {
421+
seen_strings: Vec::new(),
422+
cur_size: 0,
423+
};
424+
425+
// Intern the group keys
426+
let mut groups = Vec::new();
427+
group_by.intern(&[array], &mut groups).unwrap();
428+
429+
// Should have assigned group ids: 0, 1, 0, 2, 1
430+
assert_eq!(groups, vec![0, 1, 0, 2, 1]);
431+
assert_eq!(group_by.len(), 3); // apple, banana, cherry
432+
433+
// Emit all groups
434+
let emitted = group_by.emit(EmitTo::All).unwrap();
435+
assert_eq!(emitted.len(), 1); // One column
436+
let emitted_array = emitted[0].as_any().downcast_ref::<StringArray>().unwrap();
437+
assert_eq!(emitted_array.value(0), "apple");
438+
assert_eq!(emitted_array.value(1), "banana");
439+
assert_eq!(emitted_array.value(2), "cherry");
440+
}
441+
442+
#[test]
443+
fn test_trivial_group_by_two_columns() {
444+
// Test grouping on two string columns
445+
let col1 = vec!["a", "a", "b", "a", "b"];
446+
let col2 = vec!["x", "y", "x", "x", "y"];
447+
448+
let array1: ArrayRef = Arc::new(StringArray::from(col1));
449+
let array2: ArrayRef = Arc::new(StringArray::from(col2));
450+
451+
let mut group_by = TrivialGroupBy {
452+
seen_strings: Vec::new(),
453+
cur_size: 0,
454+
};
455+
456+
// Intern: concatenates ("a" + "x"), ("a" + "y"), ("b" + "x"), etc.
457+
let mut groups = Vec::new();
458+
group_by.intern(&[array1, array2], &mut groups).unwrap();
459+
460+
// Should have 4 distinct groups: "ax", "ay", "bx", "by"
461+
assert_eq!(group_by.len(), 4);
462+
assert_eq!(groups, vec![0, 1, 2, 0, 3]); // group ids assigned
463+
464+
// Emit all groups
465+
let emitted = group_by.emit(EmitTo::All).unwrap();
466+
assert_eq!(emitted.len(), 1); // One output column (concatenated strings)
467+
let emitted_array = emitted[0].as_any().downcast_ref::<StringArray>().unwrap();
468+
assert_eq!(emitted_array.value(0), "ax");
469+
assert_eq!(emitted_array.value(1), "ay");
470+
assert_eq!(emitted_array.value(2), "bx");
471+
assert_eq!(emitted_array.value(3), "by");
472+
}
473+
474+
#[tokio::test]
475+
async fn test_trivial_group_by_dictionary() -> Result<()> {
476+
use arrow::array::DictionaryArray;
477+
use arrow::datatypes::{DataType, Field, Schema};
478+
use datafusion_functions_aggregate::count::count_udaf;
479+
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
480+
use datafusion_physical_expr::expressions::col;
481+
use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
482+
use crate::test::TestMemoryExec;
483+
use crate::aggregates::RecordBatch;
484+
use crate::common::collect;
485+
486+
// Create schema with dictionary column and value column
487+
let schema = Arc::new(Schema::new(vec![
488+
Field::new(
489+
"color",
490+
DataType::Dictionary(
491+
Box::new(DataType::UInt8),
492+
Box::new(DataType::Utf8),
493+
),
494+
false,
495+
),
496+
Field::new("amount", DataType::UInt32, false),
497+
]));
498+
499+
// Create dictionary array
500+
let values = StringArray::from(vec!["red", "blue", "green"]);
501+
let keys = arrow::array::UInt8Array::from(vec![0, 1, 0, 2, 1]);
502+
let dict_array: ArrayRef =
503+
Arc::new(DictionaryArray::<arrow::datatypes::UInt8Type>::try_new(
504+
keys,
505+
Arc::new(values),
506+
)?);
507+
508+
// Create value column
509+
let amount_array: ArrayRef =
510+
Arc::new(arrow::array::UInt32Array::from(vec![1, 2, 3, 4, 5]));
511+
512+
// Create batch
513+
let batch =
514+
RecordBatch::try_new(Arc::clone(&schema), vec![dict_array, amount_array])?;
515+
516+
// Create in-memory source with the batch
517+
let source = TestMemoryExec::try_new(&vec![vec![batch]], Arc::clone(&schema), None)?;
518+
519+
// Create GROUP BY expression
520+
let group_expr = vec![(col("color", &schema)?, "color".to_string())];
521+
522+
// Create COUNT(amount) aggregate expression
523+
let aggr_expr = vec![Arc::new(
524+
AggregateExprBuilder::new(count_udaf(), vec![col("amount", &schema)?])
525+
.schema(Arc::clone(&schema))
526+
.alias("count_amount")
527+
.build()?,
528+
)];
529+
530+
// Create AggregateExec
531+
let aggregate_exec = AggregateExec::try_new(
532+
AggregateMode::SinglePartitioned,
533+
PhysicalGroupBy::new_single(group_expr),
534+
aggr_expr,
535+
vec![None],
536+
Arc::new(source),
537+
Arc::clone(&schema),
538+
)?;
539+
540+
541+
let output =
542+
collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
543+
println!("Output batch: {:#?}", output);
544+
Ok(())
545+
}
546+
547+
}

0 commit comments

Comments
 (0)