@@ -24,7 +24,7 @@ use std::mem::size_of_val;
2424use std:: sync:: Arc ;
2525
2626use arrow:: array:: {
27- Array , ArrayRef , ArrowPrimitiveType , AsArray , BooleanArray , BooleanBufferBuilder ,
27+ Array , ArrayRef , AsArray , BooleanArray , BooleanBufferBuilder ,
2828} ;
2929use arrow:: buffer:: BooleanBuffer ;
3030use arrow:: compute:: { self , LexicographicalComparator , SortColumn , SortOptions } ;
@@ -79,31 +79,10 @@ pub fn last_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr {
7979 . unwrap ( )
8080}
8181
82- fn create_groups_primitive_accumulator < T : ArrowPrimitiveType + Send > (
83- args : & AccumulatorArgs ,
84- is_first : bool ,
85- ) -> Result < Box < dyn GroupsAccumulator > > {
86- let Some ( ordering) = LexOrdering :: new ( args. order_bys . to_vec ( ) ) else {
87- return internal_err ! ( "Groups accumulator must have an ordering." ) ;
88- } ;
89-
90- let ordering_dtypes = ordering
91- . iter ( )
92- . map ( |e| e. expr . data_type ( args. schema ) )
93- . collect :: < Result < Vec < _ > > > ( ) ?;
94-
95- Ok ( Box :: new ( FirstLastGroupsAccumulator :: try_new (
96- PrimitiveValueState :: < T > :: new ( args. return_field . data_type ( ) . clone ( ) ) ,
97- ordering,
98- args. ignore_nulls ,
99- & ordering_dtypes,
100- is_first,
101- ) ?) )
102- }
103-
104- fn create_groups_bytes_accumulator (
82+ fn create_groups_accumulator_helper < S : ValueState + ' static > (
10583 args : & AccumulatorArgs ,
10684 is_first : bool ,
85+ state : S ,
10786) -> Result < Box < dyn GroupsAccumulator > > {
10887 let Some ( ordering) = LexOrdering :: new ( args. order_bys . to_vec ( ) ) else {
10988 return internal_err ! ( "Groups accumulator must have an ordering." ) ;
@@ -115,7 +94,7 @@ fn create_groups_bytes_accumulator(
11594 . collect :: < Result < Vec < _ > > > ( ) ?;
11695
11796 Ok ( Box :: new ( FirstLastGroupsAccumulator :: try_new (
118- BytesValueState :: try_new ( args . return_field . data_type ( ) . clone ( ) ) ? ,
97+ state ,
11998 ordering,
12099 args. ignore_nulls ,
121100 & ordering_dtypes,
@@ -128,99 +107,152 @@ fn create_groups_accumulator(
128107 is_first : bool ,
129108 function_name : & str ,
130109) -> Result < Box < dyn GroupsAccumulator > > {
131- match args. return_field . data_type ( ) {
132- DataType :: Int8 => create_groups_primitive_accumulator :: < Int8Type > ( args, is_first) ,
133- DataType :: Int16 => {
134- create_groups_primitive_accumulator :: < Int16Type > ( args, is_first)
135- }
136- DataType :: Int32 => {
137- create_groups_primitive_accumulator :: < Int32Type > ( args, is_first)
138- }
139- DataType :: Int64 => {
140- create_groups_primitive_accumulator :: < Int64Type > ( args, is_first)
141- }
142- DataType :: UInt8 => {
143- create_groups_primitive_accumulator :: < UInt8Type > ( args, is_first)
144- }
145- DataType :: UInt16 => {
146- create_groups_primitive_accumulator :: < UInt16Type > ( args, is_first)
147- }
148- DataType :: UInt32 => {
149- create_groups_primitive_accumulator :: < UInt32Type > ( args, is_first)
150- }
151- DataType :: UInt64 => {
152- create_groups_primitive_accumulator :: < UInt64Type > ( args, is_first)
153- }
154- DataType :: Float16 => {
155- create_groups_primitive_accumulator :: < Float16Type > ( args, is_first)
156- }
157- DataType :: Float32 => {
158- create_groups_primitive_accumulator :: < Float32Type > ( args, is_first)
159- }
160- DataType :: Float64 => {
161- create_groups_primitive_accumulator :: < Float64Type > ( args, is_first)
162- }
110+ let data_type = args. return_field . data_type ( ) ;
111+ match data_type {
112+ DataType :: Int8 => create_groups_accumulator_helper (
113+ args,
114+ is_first,
115+ PrimitiveValueState :: < Int8Type > :: new ( data_type. clone ( ) ) ,
116+ ) ,
117+ DataType :: Int16 => create_groups_accumulator_helper (
118+ args,
119+ is_first,
120+ PrimitiveValueState :: < Int16Type > :: new ( data_type. clone ( ) ) ,
121+ ) ,
122+ DataType :: Int32 => create_groups_accumulator_helper (
123+ args,
124+ is_first,
125+ PrimitiveValueState :: < Int32Type > :: new ( data_type. clone ( ) ) ,
126+ ) ,
127+ DataType :: Int64 => create_groups_accumulator_helper (
128+ args,
129+ is_first,
130+ PrimitiveValueState :: < Int64Type > :: new ( data_type. clone ( ) ) ,
131+ ) ,
132+ DataType :: UInt8 => create_groups_accumulator_helper (
133+ args,
134+ is_first,
135+ PrimitiveValueState :: < UInt8Type > :: new ( data_type. clone ( ) ) ,
136+ ) ,
137+ DataType :: UInt16 => create_groups_accumulator_helper (
138+ args,
139+ is_first,
140+ PrimitiveValueState :: < UInt16Type > :: new ( data_type. clone ( ) ) ,
141+ ) ,
142+ DataType :: UInt32 => create_groups_accumulator_helper (
143+ args,
144+ is_first,
145+ PrimitiveValueState :: < UInt32Type > :: new ( data_type. clone ( ) ) ,
146+ ) ,
147+ DataType :: UInt64 => create_groups_accumulator_helper (
148+ args,
149+ is_first,
150+ PrimitiveValueState :: < UInt64Type > :: new ( data_type. clone ( ) ) ,
151+ ) ,
152+ DataType :: Float16 => create_groups_accumulator_helper (
153+ args,
154+ is_first,
155+ PrimitiveValueState :: < Float16Type > :: new ( data_type. clone ( ) ) ,
156+ ) ,
157+ DataType :: Float32 => create_groups_accumulator_helper (
158+ args,
159+ is_first,
160+ PrimitiveValueState :: < Float32Type > :: new ( data_type. clone ( ) ) ,
161+ ) ,
162+ DataType :: Float64 => create_groups_accumulator_helper (
163+ args,
164+ is_first,
165+ PrimitiveValueState :: < Float64Type > :: new ( data_type. clone ( ) ) ,
166+ ) ,
163167
164- DataType :: Decimal32 ( _, _) => {
165- create_groups_primitive_accumulator :: < Decimal32Type > ( args, is_first)
166- }
167- DataType :: Decimal64 ( _, _) => {
168- create_groups_primitive_accumulator :: < Decimal64Type > ( args, is_first)
169- }
170- DataType :: Decimal128 ( _, _) => {
171- create_groups_primitive_accumulator :: < Decimal128Type > ( args, is_first)
172- }
173- DataType :: Decimal256 ( _, _) => {
174- create_groups_primitive_accumulator :: < Decimal256Type > ( args, is_first)
175- }
168+ DataType :: Decimal32 ( _, _) => create_groups_accumulator_helper (
169+ args,
170+ is_first,
171+ PrimitiveValueState :: < Decimal32Type > :: new ( data_type. clone ( ) ) ,
172+ ) ,
173+ DataType :: Decimal64 ( _, _) => create_groups_accumulator_helper (
174+ args,
175+ is_first,
176+ PrimitiveValueState :: < Decimal64Type > :: new ( data_type. clone ( ) ) ,
177+ ) ,
178+ DataType :: Decimal128 ( _, _) => create_groups_accumulator_helper (
179+ args,
180+ is_first,
181+ PrimitiveValueState :: < Decimal128Type > :: new ( data_type. clone ( ) ) ,
182+ ) ,
183+ DataType :: Decimal256 ( _, _) => create_groups_accumulator_helper (
184+ args,
185+ is_first,
186+ PrimitiveValueState :: < Decimal256Type > :: new ( data_type. clone ( ) ) ,
187+ ) ,
176188
177- DataType :: Timestamp ( TimeUnit :: Second , _) => {
178- create_groups_primitive_accumulator :: < TimestampSecondType > ( args, is_first)
179- }
180- DataType :: Timestamp ( TimeUnit :: Millisecond , _) => {
181- create_groups_primitive_accumulator :: < TimestampMillisecondType > (
182- args, is_first,
183- )
184- }
185- DataType :: Timestamp ( TimeUnit :: Microsecond , _) => {
186- create_groups_primitive_accumulator :: < TimestampMicrosecondType > (
187- args, is_first,
188- )
189- }
190- DataType :: Timestamp ( TimeUnit :: Nanosecond , _) => {
191- create_groups_primitive_accumulator :: < TimestampNanosecondType > ( args, is_first)
192- }
189+ DataType :: Timestamp ( TimeUnit :: Second , _) => create_groups_accumulator_helper (
190+ args,
191+ is_first,
192+ PrimitiveValueState :: < TimestampSecondType > :: new ( data_type. clone ( ) ) ,
193+ ) ,
194+ DataType :: Timestamp ( TimeUnit :: Millisecond , _) => create_groups_accumulator_helper (
195+ args,
196+ is_first,
197+ PrimitiveValueState :: < TimestampMillisecondType > :: new ( data_type. clone ( ) ) ,
198+ ) ,
199+ DataType :: Timestamp ( TimeUnit :: Microsecond , _) => create_groups_accumulator_helper (
200+ args,
201+ is_first,
202+ PrimitiveValueState :: < TimestampMicrosecondType > :: new ( data_type. clone ( ) ) ,
203+ ) ,
204+ DataType :: Timestamp ( TimeUnit :: Nanosecond , _) => create_groups_accumulator_helper (
205+ args,
206+ is_first,
207+ PrimitiveValueState :: < TimestampNanosecondType > :: new ( data_type. clone ( ) ) ,
208+ ) ,
193209
194- DataType :: Date32 => {
195- create_groups_primitive_accumulator :: < Date32Type > ( args, is_first)
196- }
197- DataType :: Date64 => {
198- create_groups_primitive_accumulator :: < Date64Type > ( args, is_first)
199- }
200- DataType :: Time32 ( TimeUnit :: Second ) => {
201- create_groups_primitive_accumulator :: < Time32SecondType > ( args, is_first)
202- }
203- DataType :: Time32 ( TimeUnit :: Millisecond ) => {
204- create_groups_primitive_accumulator :: < Time32MillisecondType > ( args, is_first)
205- }
206- DataType :: Time64 ( TimeUnit :: Microsecond ) => {
207- create_groups_primitive_accumulator :: < Time64MicrosecondType > ( args, is_first)
208- }
209- DataType :: Time64 ( TimeUnit :: Nanosecond ) => {
210- create_groups_primitive_accumulator :: < Time64NanosecondType > ( args, is_first)
211- }
210+ DataType :: Date32 => create_groups_accumulator_helper (
211+ args,
212+ is_first,
213+ PrimitiveValueState :: < Date32Type > :: new ( data_type. clone ( ) ) ,
214+ ) ,
215+ DataType :: Date64 => create_groups_accumulator_helper (
216+ args,
217+ is_first,
218+ PrimitiveValueState :: < Date64Type > :: new ( data_type. clone ( ) ) ,
219+ ) ,
220+ DataType :: Time32 ( TimeUnit :: Second ) => create_groups_accumulator_helper (
221+ args,
222+ is_first,
223+ PrimitiveValueState :: < Time32SecondType > :: new ( data_type. clone ( ) ) ,
224+ ) ,
225+ DataType :: Time32 ( TimeUnit :: Millisecond ) => create_groups_accumulator_helper (
226+ args,
227+ is_first,
228+ PrimitiveValueState :: < Time32MillisecondType > :: new ( data_type. clone ( ) ) ,
229+ ) ,
230+ DataType :: Time64 ( TimeUnit :: Microsecond ) => create_groups_accumulator_helper (
231+ args,
232+ is_first,
233+ PrimitiveValueState :: < Time64MicrosecondType > :: new ( data_type. clone ( ) ) ,
234+ ) ,
235+ DataType :: Time64 ( TimeUnit :: Nanosecond ) => create_groups_accumulator_helper (
236+ args,
237+ is_first,
238+ PrimitiveValueState :: < Time64NanosecondType > :: new ( data_type. clone ( ) ) ,
239+ ) ,
212240
213241 DataType :: Utf8
214242 | DataType :: LargeUtf8
215243 | DataType :: Utf8View
216244 | DataType :: Binary
217245 | DataType :: LargeBinary
218- | DataType :: BinaryView => create_groups_bytes_accumulator ( args, is_first) ,
246+ | DataType :: BinaryView => create_groups_accumulator_helper (
247+ args,
248+ is_first,
249+ BytesValueState :: try_new ( data_type. clone ( ) ) ?,
250+ ) ,
219251
220252 _ => internal_err ! (
221253 "GroupsAccumulator not supported for {}({})" ,
222254 function_name,
223- args . return_field . data_type( )
255+ data_type
224256 ) ,
225257 }
226258}
@@ -419,7 +451,7 @@ struct FirstLastGroupsAccumulator<S: ValueState> {
419451 // to avoid calling `ScalarValue::size_of_vec` by Self.size.
420452 size_of_orderings : usize ,
421453
422- // buffer for `get_filtered_min_of_each_group `
454+ // buffer for `get_filtered_extreme_of_each_group `
423455 // extreme_of_each_group_buf.0[group_idx] -> idx_in_val
424456 // only valid if extreme_of_each_group_buf.1[group_idx] == true
425457 extreme_of_each_group_buf : ( Vec < usize > , BooleanBufferBuilder ) ,
0 commit comments