@@ -23,9 +23,7 @@ use std::hash::Hash;
2323use std:: mem:: size_of_val;
2424use std:: sync:: Arc ;
2525
26- use arrow:: array:: {
27- Array , ArrayRef , ArrowPrimitiveType , AsArray , BooleanArray , BooleanBufferBuilder ,
28- } ;
26+ use arrow:: array:: { Array , ArrayRef , AsArray , BooleanArray , BooleanBufferBuilder } ;
2927use arrow:: buffer:: BooleanBuffer ;
3028use arrow:: compute:: { self , LexicographicalComparator , SortColumn , SortOptions } ;
3129use arrow:: datatypes:: {
@@ -79,31 +77,10 @@ pub fn last_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
7977 . unwrap ( )
8078}
8179
82- fn create_groups_primitive_accumulator < T : ArrowPrimitiveType + Send > (
83- args : & AccumulatorArgs ,
84- is_first : bool ,
85- ) -> Result < Box < dyn GroupsAccumulator > > {
86- let Some ( ordering) = LexOrdering :: new ( args. order_bys . to_vec ( ) ) else {
87- return internal_err ! ( "Groups accumulator must have an ordering." ) ;
88- } ;
89-
90- let ordering_dtypes = ordering
91- . iter ( )
92- . map ( |e| e. expr . data_type ( args. schema ) )
93- . collect :: < Result < Vec < _ > > > ( ) ?;
94-
95- Ok ( Box :: new ( FirstLastGroupsAccumulator :: try_new (
96- PrimitiveValueState :: < T > :: new ( args. return_field . data_type ( ) . clone ( ) ) ,
97- ordering,
98- args. ignore_nulls ,
99- & ordering_dtypes,
100- is_first,
101- ) ?) )
102- }
103-
104- fn create_groups_bytes_accumulator (
80+ fn create_groups_accumulator_helper < S : ValueState + ' static > (
10581 args : & AccumulatorArgs ,
10682 is_first : bool ,
83+ state : S ,
10784) -> Result < Box < dyn GroupsAccumulator > > {
10885 let Some ( ordering) = LexOrdering :: new ( args. order_bys . to_vec ( ) ) else {
10986 return internal_err ! ( "Groups accumulator must have an ordering." ) ;
@@ -115,7 +92,7 @@ fn create_groups_bytes_accumulator(
11592 . collect :: < Result < Vec < _ > > > ( ) ?;
11693
11794 Ok ( Box :: new ( FirstLastGroupsAccumulator :: try_new (
118- BytesValueState :: try_new ( args . return_field . data_type ( ) . clone ( ) ) ? ,
95+ state ,
11996 ordering,
12097 args. ignore_nulls ,
12198 & ordering_dtypes,
@@ -128,99 +105,77 @@ fn create_groups_accumulator(
128105 is_first : bool ,
129106 function_name : & str ,
130107) -> Result < Box < dyn GroupsAccumulator > > {
131- match args. return_field . data_type ( ) {
132- DataType :: Int8 => create_groups_primitive_accumulator :: < Int8Type > ( args, is_first) ,
133- DataType :: Int16 => {
134- create_groups_primitive_accumulator :: < Int16Type > ( args, is_first)
135- }
136- DataType :: Int32 => {
137- create_groups_primitive_accumulator :: < Int32Type > ( args, is_first)
138- }
139- DataType :: Int64 => {
140- create_groups_primitive_accumulator :: < Int64Type > ( args, is_first)
141- }
142- DataType :: UInt8 => {
143- create_groups_primitive_accumulator :: < UInt8Type > ( args, is_first)
144- }
145- DataType :: UInt16 => {
146- create_groups_primitive_accumulator :: < UInt16Type > ( args, is_first)
147- }
148- DataType :: UInt32 => {
149- create_groups_primitive_accumulator :: < UInt32Type > ( args, is_first)
150- }
151- DataType :: UInt64 => {
152- create_groups_primitive_accumulator :: < UInt64Type > ( args, is_first)
153- }
154- DataType :: Float16 => {
155- create_groups_primitive_accumulator :: < Float16Type > ( args, is_first)
156- }
157- DataType :: Float32 => {
158- create_groups_primitive_accumulator :: < Float32Type > ( args, is_first)
159- }
160- DataType :: Float64 => {
161- create_groups_primitive_accumulator :: < Float64Type > ( args, is_first)
162- }
108+ let data_type = args. return_field . data_type ( ) ;
109+
110+ macro_rules! instantiate_primitive {
111+ ( $t: ty) => {
112+ create_groups_accumulator_helper(
113+ args,
114+ is_first,
115+ PrimitiveValueState :: <$t>:: new( data_type. clone( ) ) ,
116+ )
117+ } ;
118+ }
163119
164- DataType :: Decimal32 ( _, _) => {
165- create_groups_primitive_accumulator :: < Decimal32Type > ( args, is_first)
166- }
167- DataType :: Decimal64 ( _, _) => {
168- create_groups_primitive_accumulator :: < Decimal64Type > ( args, is_first)
169- }
170- DataType :: Decimal128 ( _, _) => {
171- create_groups_primitive_accumulator :: < Decimal128Type > ( args, is_first)
172- }
173- DataType :: Decimal256 ( _, _) => {
174- create_groups_primitive_accumulator :: < Decimal256Type > ( args, is_first)
175- }
120+ match data_type {
121+ DataType :: Int8 => instantiate_primitive ! ( Int8Type ) ,
122+ DataType :: Int16 => instantiate_primitive ! ( Int16Type ) ,
123+ DataType :: Int32 => instantiate_primitive ! ( Int32Type ) ,
124+ DataType :: Int64 => instantiate_primitive ! ( Int64Type ) ,
125+ DataType :: UInt8 => instantiate_primitive ! ( UInt8Type ) ,
126+ DataType :: UInt16 => instantiate_primitive ! ( UInt16Type ) ,
127+ DataType :: UInt32 => instantiate_primitive ! ( UInt32Type ) ,
128+ DataType :: UInt64 => instantiate_primitive ! ( UInt64Type ) ,
129+ DataType :: Float16 => instantiate_primitive ! ( Float16Type ) ,
130+ DataType :: Float32 => instantiate_primitive ! ( Float32Type ) ,
131+ DataType :: Float64 => instantiate_primitive ! ( Float64Type ) ,
132+
133+ DataType :: Decimal32 ( _, _) => instantiate_primitive ! ( Decimal32Type ) ,
134+ DataType :: Decimal64 ( _, _) => instantiate_primitive ! ( Decimal64Type ) ,
135+ DataType :: Decimal128 ( _, _) => instantiate_primitive ! ( Decimal128Type ) ,
136+ DataType :: Decimal256 ( _, _) => instantiate_primitive ! ( Decimal256Type ) ,
176137
177138 DataType :: Timestamp ( TimeUnit :: Second , _) => {
178- create_groups_primitive_accumulator :: < TimestampSecondType > ( args , is_first )
139+ instantiate_primitive ! ( TimestampSecondType )
179140 }
180141 DataType :: Timestamp ( TimeUnit :: Millisecond , _) => {
181- create_groups_primitive_accumulator :: < TimestampMillisecondType > (
182- args, is_first,
183- )
142+ instantiate_primitive ! ( TimestampMillisecondType )
184143 }
185144 DataType :: Timestamp ( TimeUnit :: Microsecond , _) => {
186- create_groups_primitive_accumulator :: < TimestampMicrosecondType > (
187- args, is_first,
188- )
145+ instantiate_primitive ! ( TimestampMicrosecondType )
189146 }
190147 DataType :: Timestamp ( TimeUnit :: Nanosecond , _) => {
191- create_groups_primitive_accumulator :: < TimestampNanosecondType > ( args , is_first )
148+ instantiate_primitive ! ( TimestampNanosecondType )
192149 }
193150
194- DataType :: Date32 => {
195- create_groups_primitive_accumulator :: < Date32Type > ( args, is_first)
196- }
197- DataType :: Date64 => {
198- create_groups_primitive_accumulator :: < Date64Type > ( args, is_first)
199- }
200- DataType :: Time32 ( TimeUnit :: Second ) => {
201- create_groups_primitive_accumulator :: < Time32SecondType > ( args, is_first)
202- }
151+ DataType :: Date32 => instantiate_primitive ! ( Date32Type ) ,
152+ DataType :: Date64 => instantiate_primitive ! ( Date64Type ) ,
153+ DataType :: Time32 ( TimeUnit :: Second ) => instantiate_primitive ! ( Time32SecondType ) ,
203154 DataType :: Time32 ( TimeUnit :: Millisecond ) => {
204- create_groups_primitive_accumulator :: < Time32MillisecondType > ( args , is_first )
155+ instantiate_primitive ! ( Time32MillisecondType )
205156 }
206157 DataType :: Time64 ( TimeUnit :: Microsecond ) => {
207- create_groups_primitive_accumulator :: < Time64MicrosecondType > ( args , is_first )
158+ instantiate_primitive ! ( Time64MicrosecondType )
208159 }
209160 DataType :: Time64 ( TimeUnit :: Nanosecond ) => {
210- create_groups_primitive_accumulator :: < Time64NanosecondType > ( args , is_first )
161+ instantiate_primitive ! ( Time64NanosecondType )
211162 }
212163
213164 DataType :: Utf8
214165 | DataType :: LargeUtf8
215166 | DataType :: Utf8View
216167 | DataType :: Binary
217168 | DataType :: LargeBinary
218- | DataType :: BinaryView => create_groups_bytes_accumulator ( args, is_first) ,
169+ | DataType :: BinaryView => create_groups_accumulator_helper (
170+ args,
171+ is_first,
172+ BytesValueState :: try_new ( data_type. clone ( ) ) ?,
173+ ) ,
219174
220175 _ => internal_err ! (
221176 "GroupsAccumulator not supported for {}({})" ,
222177 function_name,
223- args . return_field . data_type( )
178+ data_type
224179 ) ,
225180 }
226181}
@@ -419,7 +374,7 @@ struct FirstLastGroupsAccumulator<S: ValueState> {
419374 // to avoid calling `ScalarValue::size_of_vec` by Self.size.
420375 size_of_orderings : usize ,
421376
422- // buffer for `get_filtered_min_of_each_group `
377+ // buffer for `get_filtered_extreme_of_each_group `
423378 // extreme_of_each_group_buf.0[group_idx] -> idx_in_val
424379 // only valid if extreme_of_each_group_buf.1[group_idx] == true
425380 extreme_of_each_group_buf : ( Vec < usize > , BooleanBufferBuilder ) ,
0 commit comments