@@ -23,9 +23,9 @@ use arrow::array::{
2323use arrow:: compute:: cast;
2424use arrow:: datatypes:: { DataType , SchemaRef } ;
2525use arrow:: row:: { RowConverter , Rows , SortField } ;
26- use datafusion_common:: Result ;
2726use datafusion_common:: hash_utils:: RandomState ;
2827use datafusion_common:: hash_utils:: create_hashes;
28+ use datafusion_common:: { Result , internal_err} ;
2929use datafusion_execution:: memory_pool:: proxy:: { HashTableAllocExt , VecAllocExt } ;
3030use datafusion_expr:: EmitTo ;
3131use hashbrown:: hash_table:: HashTable ;
@@ -324,3 +324,224 @@ fn dictionary_encode_if_necessary(
324324 ( _, _) => Ok ( Arc :: < dyn Array > :: clone ( array) ) ,
325325 }
326326}
327+
328+ mod playground {
329+ use std:: ops:: Index ;
330+
331+ use arrow:: array:: { AsArray , Datum } ;
332+ use datafusion_execution:: TaskContext ;
333+
334+ use crate :: { ExecutionPlan , test:: TestMemoryExec } ;
335+
336+ use super :: * ;
337+ use arrow:: array:: { Array , ArrayRef , StringArray } ;
338+ use std:: string:: String ;
339+ struct TrivialGroupBy {
340+ seen_strings : Vec < String > ,
341+ cur_size : usize ,
342+ }
343+ impl GroupValues for TrivialGroupBy {
344+ // trivial apprach assume theres only one columns
345+ fn intern ( & mut self , cols : & [ ArrayRef ] , groups : & mut Vec < usize > ) -> Result < ( ) > {
346+ let n_rows = cols[ 0 ] . len ( ) ;
347+ let column_count = cols. len ( ) ;
348+ // iterate all the rows, for each row, concate each value into a final string
349+ for row_idx in 0 ..n_rows {
350+ // grab the underlying value; assume its a string
351+ let mut cur_str = String :: from ( "" ) ;
352+ for col_idx in 0 ..column_count {
353+ let array = cols. get ( col_idx) . unwrap ( ) ;
354+ let string_array =
355+ array. as_any ( ) . downcast_ref :: < StringArray > ( ) . unwrap ( ) ;
356+ if string_array. is_valid ( row_idx) {
357+ let value = string_array. value ( row_idx) ;
358+ cur_str. push_str ( value) ;
359+ }
360+ }
361+ let idx = if let Some ( i) =
362+ self . seen_strings . iter ( ) . position ( |x| x == & cur_str)
363+ {
364+ i
365+ } else {
366+ self . seen_strings . push ( cur_str) ;
367+ self . seen_strings . len ( ) - 1
368+ } ;
369+ groups. push ( idx) ;
370+ }
371+ println ! ( "{:?}" , self . seen_strings) ;
372+ Ok ( ( ) )
373+ }
374+ fn len ( & self ) -> usize {
375+ self . seen_strings . len ( )
376+ }
377+ fn is_empty ( & self ) -> bool {
378+ self . seen_strings . is_empty ( )
379+ }
380+ fn emit ( & mut self , emit_to : EmitTo ) -> Result < Vec < ArrayRef > > {
381+ let strings_to_emit = match emit_to {
382+ EmitTo :: All => {
383+ // take all groups, clear internal state
384+ std:: mem:: take ( & mut self . seen_strings )
385+ }
386+ EmitTo :: First ( n) => {
387+ // take only the first n groups
388+ // drain removes them from seen_strings
389+ self . seen_strings . drain ( ..n) . collect ( )
390+ }
391+ } ;
392+
393+ // convert our Vec<String> back into an Arrow StringArray
394+ let array: ArrayRef = Arc :: new ( StringArray :: from (
395+ strings_to_emit
396+ . iter ( )
397+ . map ( |s| s. as_str ( ) )
398+ . collect :: < Vec < _ > > ( ) ,
399+ ) ) ;
400+
401+ // return one array per GROUP BY column
402+ // since we're trivial and only support one column, just return one
403+ Ok ( vec ! [ array] )
404+ }
405+
406+ fn size ( & self ) -> usize {
407+ self . cur_size
408+ }
409+ fn clear_shrink ( & mut self , num_rows : usize ) {
410+ let _ = num_rows;
411+ }
412+ }
413+
414+ #[ test]
415+ fn test_trivial_group_by_single_column ( ) {
416+ // Test grouping on a single string column
417+ let strings = vec ! [ "apple" , "banana" , "apple" , "cherry" , "banana" ] ;
418+ let array: ArrayRef = Arc :: new ( StringArray :: from ( strings) ) ;
419+
420+ let mut group_by = TrivialGroupBy {
421+ seen_strings : Vec :: new ( ) ,
422+ cur_size : 0 ,
423+ } ;
424+
425+ // Intern the group keys
426+ let mut groups = Vec :: new ( ) ;
427+ group_by. intern ( & [ array] , & mut groups) . unwrap ( ) ;
428+
429+ // Should have assigned group ids: 0, 1, 0, 2, 1
430+ assert_eq ! ( groups, vec![ 0 , 1 , 0 , 2 , 1 ] ) ;
431+ assert_eq ! ( group_by. len( ) , 3 ) ; // apple, banana, cherry
432+
433+ // Emit all groups
434+ let emitted = group_by. emit ( EmitTo :: All ) . unwrap ( ) ;
435+ assert_eq ! ( emitted. len( ) , 1 ) ; // One column
436+ let emitted_array = emitted[ 0 ] . as_any ( ) . downcast_ref :: < StringArray > ( ) . unwrap ( ) ;
437+ assert_eq ! ( emitted_array. value( 0 ) , "apple" ) ;
438+ assert_eq ! ( emitted_array. value( 1 ) , "banana" ) ;
439+ assert_eq ! ( emitted_array. value( 2 ) , "cherry" ) ;
440+ }
441+
442+ #[ test]
443+ fn test_trivial_group_by_two_columns ( ) {
444+ // Test grouping on two string columns
445+ let col1 = vec ! [ "a" , "a" , "b" , "a" , "b" ] ;
446+ let col2 = vec ! [ "x" , "y" , "x" , "x" , "y" ] ;
447+
448+ let array1: ArrayRef = Arc :: new ( StringArray :: from ( col1) ) ;
449+ let array2: ArrayRef = Arc :: new ( StringArray :: from ( col2) ) ;
450+
451+ let mut group_by = TrivialGroupBy {
452+ seen_strings : Vec :: new ( ) ,
453+ cur_size : 0 ,
454+ } ;
455+
456+ // Intern: concatenates ("a" + "x"), ("a" + "y"), ("b" + "x"), etc.
457+ let mut groups = Vec :: new ( ) ;
458+ group_by. intern ( & [ array1, array2] , & mut groups) . unwrap ( ) ;
459+
460+ // Should have 4 distinct groups: "ax", "ay", "bx", "by"
461+ assert_eq ! ( group_by. len( ) , 4 ) ;
462+ assert_eq ! ( groups, vec![ 0 , 1 , 2 , 0 , 3 ] ) ; // group ids assigned
463+
464+ // Emit all groups
465+ let emitted = group_by. emit ( EmitTo :: All ) . unwrap ( ) ;
466+ assert_eq ! ( emitted. len( ) , 1 ) ; // One output column (concatenated strings)
467+ let emitted_array = emitted[ 0 ] . as_any ( ) . downcast_ref :: < StringArray > ( ) . unwrap ( ) ;
468+ assert_eq ! ( emitted_array. value( 0 ) , "ax" ) ;
469+ assert_eq ! ( emitted_array. value( 1 ) , "ay" ) ;
470+ assert_eq ! ( emitted_array. value( 2 ) , "bx" ) ;
471+ assert_eq ! ( emitted_array. value( 3 ) , "by" ) ;
472+ }
473+
474+ #[ tokio:: test]
475+ async fn test_trivial_group_by_dictionary ( ) -> Result < ( ) > {
476+ use arrow:: array:: DictionaryArray ;
477+ use arrow:: datatypes:: { DataType , Field , Schema } ;
478+ use datafusion_functions_aggregate:: count:: count_udaf;
479+ use datafusion_physical_expr:: aggregate:: AggregateExprBuilder ;
480+ use datafusion_physical_expr:: expressions:: col;
481+ use crate :: aggregates:: { AggregateExec , AggregateMode , PhysicalGroupBy } ;
482+ use crate :: test:: TestMemoryExec ;
483+ use crate :: aggregates:: RecordBatch ;
484+ use crate :: common:: collect;
485+
486+ // Create schema with dictionary column and value column
487+ let schema = Arc :: new ( Schema :: new ( vec ! [
488+ Field :: new(
489+ "color" ,
490+ DataType :: Dictionary (
491+ Box :: new( DataType :: UInt8 ) ,
492+ Box :: new( DataType :: Utf8 ) ,
493+ ) ,
494+ false ,
495+ ) ,
496+ Field :: new( "amount" , DataType :: UInt32 , false ) ,
497+ ] ) ) ;
498+
499+ // Create dictionary array
500+ let values = StringArray :: from ( vec ! [ "red" , "blue" , "green" ] ) ;
501+ let keys = arrow:: array:: UInt8Array :: from ( vec ! [ 0 , 1 , 0 , 2 , 1 ] ) ;
502+ let dict_array: ArrayRef =
503+ Arc :: new ( DictionaryArray :: < arrow:: datatypes:: UInt8Type > :: try_new (
504+ keys,
505+ Arc :: new ( values) ,
506+ ) ?) ;
507+
508+ // Create value column
509+ let amount_array: ArrayRef =
510+ Arc :: new ( arrow:: array:: UInt32Array :: from ( vec ! [ 1 , 2 , 3 , 4 , 5 ] ) ) ;
511+
512+ // Create batch
513+ let batch =
514+ RecordBatch :: try_new ( Arc :: clone ( & schema) , vec ! [ dict_array, amount_array] ) ?;
515+
516+ // Create in-memory source with the batch
517+ let source = TestMemoryExec :: try_new ( & vec ! [ vec![ batch] ] , Arc :: clone ( & schema) , None ) ?;
518+
519+ // Create GROUP BY expression
520+ let group_expr = vec ! [ ( col( "color" , & schema) ?, "color" . to_string( ) ) ] ;
521+
522+ // Create COUNT(amount) aggregate expression
523+ let aggr_expr = vec ! [ Arc :: new(
524+ AggregateExprBuilder :: new( count_udaf( ) , vec![ col( "amount" , & schema) ?] )
525+ . schema( Arc :: clone( & schema) )
526+ . alias( "count_amount" )
527+ . build( ) ?,
528+ ) ] ;
529+
530+ // Create AggregateExec
531+ let aggregate_exec = AggregateExec :: try_new (
532+ AggregateMode :: SinglePartitioned ,
533+ PhysicalGroupBy :: new_single ( group_expr) ,
534+ aggr_expr,
535+ vec ! [ None ] ,
536+ Arc :: new ( source) ,
537+ Arc :: clone ( & schema) ,
538+ ) ?;
539+
540+
541+ let output =
542+ collect ( aggregate_exec. execute ( 0 , Arc :: new ( TaskContext :: default ( ) ) ) ?) . await ?;
543+ println ! ( "Output batch: {:#?}" , output) ;
544+ Ok ( ( ) )
545+ }
546+
547+ }
0 commit comments