1818use crate :: aggregates:: group_values:: GroupValues ;
1919use crate :: hash_utils:: RandomState ;
2020use arrow:: array:: {
21- Array , ArrayRef , BinaryArray , BinaryBuilder , BinaryViewArray , BinaryViewBuilder ,
22- DictionaryArray , LargeBinaryArray , LargeBinaryBuilder , LargeStringArray ,
23- LargeStringBuilder , PrimitiveArray , PrimitiveBuilder , StringArray , StringBuilder ,
24- StringViewArray , StringViewBuilder , UInt64Array ,
21+ Array , ArrayRef , AsArray , BinaryArray , BinaryBuilder , BinaryViewArray ,
22+ BinaryViewBuilder , DictionaryArray , LargeBinaryArray , LargeBinaryBuilder ,
23+ LargeStringArray , LargeStringBuilder , PrimitiveArray , PrimitiveBuilder , StringArray ,
24+ StringBuilder , StringViewArray , StringViewBuilder , UInt64Array ,
25+ } ;
26+ use arrow:: datatypes:: {
27+ ArrowDictionaryKeyType , ArrowNativeType , DataType , Int8Type , Int16Type , Int32Type ,
28+ Int64Type , UInt8Type , UInt16Type , UInt32Type , UInt64Type ,
2529} ;
26- use arrow:: datatypes:: { ArrowDictionaryKeyType , ArrowNativeType , DataType } ;
2730use datafusion_common:: Result ;
2831use datafusion_common:: hash_utils:: create_hashes;
2932use datafusion_expr:: EmitTo ;
@@ -129,6 +132,46 @@ impl<K: ArrowDictionaryKeyType + Send> GroupValuesDictionary<K> {
129132 . downcast_ref :: < BinaryViewArray > ( )
130133 . expect ( "Expected BinaryViewArray" )
131134 . value ( index) ,
135+ DataType :: Int8 => {
136+ let arr = values. as_primitive :: < Int8Type > ( ) ;
137+ let val = arr. value ( index) ;
138+ unsafe { std:: slice:: from_raw_parts ( & val as * const i8 as * const u8 , 1 ) }
139+ }
140+ DataType :: Int16 => {
141+ let arr = values. as_primitive :: < Int16Type > ( ) ;
142+ let val = arr. value ( index) ;
143+ unsafe { std:: slice:: from_raw_parts ( & val as * const i16 as * const u8 , 2 ) }
144+ }
145+ DataType :: Int32 => {
146+ let arr = values. as_primitive :: < Int32Type > ( ) ;
147+ let val = arr. value ( index) ;
148+ unsafe { std:: slice:: from_raw_parts ( & val as * const i32 as * const u8 , 4 ) }
149+ }
150+ DataType :: Int64 => {
151+ let arr = values. as_primitive :: < Int64Type > ( ) ;
152+ let val = arr. value ( index) ;
153+ unsafe { std:: slice:: from_raw_parts ( & val as * const i64 as * const u8 , 8 ) }
154+ }
155+ DataType :: UInt8 => {
156+ let arr = values. as_primitive :: < UInt8Type > ( ) ;
157+ let val = arr. value ( index) ;
158+ unsafe { std:: slice:: from_raw_parts ( & val as * const u8 , 1 ) }
159+ }
160+ DataType :: UInt16 => {
161+ let arr = values. as_primitive :: < UInt16Type > ( ) ;
162+ let val = arr. value ( index) ;
163+ unsafe { std:: slice:: from_raw_parts ( & val as * const u16 as * const u8 , 2 ) }
164+ }
165+ DataType :: UInt32 => {
166+ let arr = values. as_primitive :: < UInt32Type > ( ) ;
167+ let val = arr. value ( index) ;
168+ unsafe { std:: slice:: from_raw_parts ( & val as * const u32 as * const u8 , 4 ) }
169+ }
170+ DataType :: UInt64 => {
171+ let arr = values. as_primitive :: < UInt64Type > ( ) ;
172+ let val = arr. value ( index) ;
173+ unsafe { std:: slice:: from_raw_parts ( & val as * const u64 as * const u8 , 8 ) }
174+ }
132175 other => unimplemented ! ( "get_raw_bytes not implemented for {other:?}" ) ,
133176 }
134177 }
@@ -145,6 +188,15 @@ impl<K: ArrowDictionaryKeyType + Send> GroupValuesDictionary<K> {
145188 }
146189 // for primitives use a byte sequence that is a different length than the native type
147190 // a real i8 is always exactly 1 byte so 2 bytes can never match a real value
191+ DataType :: Int8 | DataType :: UInt8 => vec ! [ 0xFF , 0xFF ] ,
192+ // a real i16/u16 is always exactly 2 bytes so 3 bytes can never match
193+ DataType :: Int16 | DataType :: UInt16 => vec ! [ 0xFF , 0xFF , 0xFF ] ,
194+ // a real i32/u32/f32 is always exactly 4 bytes so 5 bytes can never match
195+ DataType :: Int32 | DataType :: UInt32 => vec ! [ 0xFF , 0xFF , 0xFF , 0xFF , 0xFF ] ,
196+ // a real i64/u64/f64 is always exactly 8 bytes so 9 bytes can never match
197+ DataType :: Int64 | DataType :: UInt64 => {
198+ vec ! [ 0xFF , 0xFF , 0xFF , 0xFF , 0xFF , 0xFF , 0xFF , 0xFF , 0xFF ]
199+ }
148200 other => unimplemented ! ( "sentinel_repr not implemented for {other:?}" ) ,
149201 }
150202 }
@@ -255,6 +307,110 @@ impl<K: ArrowDictionaryKeyType + Send> GroupValuesDictionary<K> {
255307 }
256308 Ok ( Arc :: new ( builder. finish ( ) ) as ArrayRef )
257309 }
310+ DataType :: Int8 => {
311+ let mut builder = PrimitiveBuilder :: < Int8Type > :: new ( ) ;
312+ for raw_bytes in raw {
313+ if raw_bytes == & sentinel {
314+ builder. append_null ( ) ;
315+ } else {
316+ builder. append_value ( i8:: from_ne_bytes (
317+ raw_bytes. as_slice ( ) . try_into ( ) . unwrap ( ) ,
318+ ) ) ;
319+ }
320+ }
321+ Ok ( Arc :: new ( builder. finish ( ) ) as ArrayRef )
322+ }
323+ DataType :: Int16 => {
324+ let mut builder = PrimitiveBuilder :: < Int16Type > :: new ( ) ;
325+ for raw_bytes in raw {
326+ if raw_bytes == & sentinel {
327+ builder. append_null ( ) ;
328+ } else {
329+ builder. append_value ( i16:: from_ne_bytes (
330+ raw_bytes. as_slice ( ) . try_into ( ) . unwrap ( ) ,
331+ ) ) ;
332+ }
333+ }
334+ Ok ( Arc :: new ( builder. finish ( ) ) as ArrayRef )
335+ }
336+ DataType :: Int32 => {
337+ let mut builder = PrimitiveBuilder :: < Int32Type > :: new ( ) ;
338+ for raw_bytes in raw {
339+ if raw_bytes == & sentinel {
340+ builder. append_null ( ) ;
341+ } else {
342+ builder. append_value ( i32:: from_ne_bytes (
343+ raw_bytes. as_slice ( ) . try_into ( ) . unwrap ( ) ,
344+ ) ) ;
345+ }
346+ }
347+ Ok ( Arc :: new ( builder. finish ( ) ) as ArrayRef )
348+ }
349+ DataType :: Int64 => {
350+ let mut builder = PrimitiveBuilder :: < Int64Type > :: new ( ) ;
351+ for raw_bytes in raw {
352+ if raw_bytes == & sentinel {
353+ builder. append_null ( ) ;
354+ } else {
355+ builder. append_value ( i64:: from_ne_bytes (
356+ raw_bytes. as_slice ( ) . try_into ( ) . unwrap ( ) ,
357+ ) ) ;
358+ }
359+ }
360+ Ok ( Arc :: new ( builder. finish ( ) ) as ArrayRef )
361+ }
362+ DataType :: UInt8 => {
363+ let mut builder = PrimitiveBuilder :: < UInt8Type > :: new ( ) ;
364+ for raw_bytes in raw {
365+ if raw_bytes == & sentinel {
366+ builder. append_null ( ) ;
367+ } else {
368+ builder. append_value ( u8:: from_ne_bytes (
369+ raw_bytes. as_slice ( ) . try_into ( ) . unwrap ( ) ,
370+ ) ) ;
371+ }
372+ }
373+ Ok ( Arc :: new ( builder. finish ( ) ) as ArrayRef )
374+ }
375+ DataType :: UInt16 => {
376+ let mut builder = PrimitiveBuilder :: < UInt16Type > :: new ( ) ;
377+ for raw_bytes in raw {
378+ if raw_bytes == & sentinel {
379+ builder. append_null ( ) ;
380+ } else {
381+ builder. append_value ( u16:: from_ne_bytes (
382+ raw_bytes. as_slice ( ) . try_into ( ) . unwrap ( ) ,
383+ ) ) ;
384+ }
385+ }
386+ Ok ( Arc :: new ( builder. finish ( ) ) as ArrayRef )
387+ }
388+ DataType :: UInt32 => {
389+ let mut builder = PrimitiveBuilder :: < UInt32Type > :: new ( ) ;
390+ for raw_bytes in raw {
391+ if raw_bytes == & sentinel {
392+ builder. append_null ( ) ;
393+ } else {
394+ builder. append_value ( u32:: from_ne_bytes (
395+ raw_bytes. as_slice ( ) . try_into ( ) . unwrap ( ) ,
396+ ) ) ;
397+ }
398+ }
399+ Ok ( Arc :: new ( builder. finish ( ) ) as ArrayRef )
400+ }
401+ DataType :: UInt64 => {
402+ let mut builder = PrimitiveBuilder :: < UInt64Type > :: new ( ) ;
403+ for raw_bytes in raw {
404+ if raw_bytes == & sentinel {
405+ builder. append_null ( ) ;
406+ } else {
407+ builder. append_value ( u64:: from_ne_bytes (
408+ raw_bytes. as_slice ( ) . try_into ( ) . unwrap ( ) ,
409+ ) ) ;
410+ }
411+ }
412+ Ok ( Arc :: new ( builder. finish ( ) ) as ArrayRef )
413+ }
258414 other => Err ( datafusion_common:: DataFusionError :: NotImplemented ( format ! (
259415 "transform_into_array not implemented for {other:?}"
260416 ) ) ) ,
@@ -482,8 +638,7 @@ impl<K: ArrowDictionaryKeyType + Send> GroupValues for GroupValuesDictionary<K>
482638#[ cfg( test) ]
483639mod group_values_trait_test {
484640 use super :: * ;
485- use arrow:: array:: { DictionaryArray , Int32Array , StringArray , UInt8Array } ;
486- use arrow:: datatypes:: { Int32Type , UInt8Type } ;
641+ use arrow:: array:: { DictionaryArray , StringArray , UInt8Array } ;
487642 use std:: sync:: Arc ;
488643
489644 fn create_dict_array ( keys : Vec < u8 > , values : Vec < & str > ) -> ArrayRef {
@@ -1119,6 +1274,7 @@ mod group_values_trait_test {
11191274
11201275 mod data_correctness {
11211276 use super :: * ;
1277+ use arrow:: array:: Int32Array ;
11221278
11231279 #[ test]
11241280 fn test_group_assignment_order ( ) {
0 commit comments