@@ -128,6 +128,8 @@ pub(super) struct StreamedBatch {
128128 pub join_arrays : Vec < ArrayRef > ,
129129 /// Chunks of indices from buffered side (may be nulls) joined to streamed
130130 pub output_indices : Vec < StreamedJoinedChunk > ,
131+ /// Total number of output rows across all chunks in `output_indices`
132+ pub num_output_rows : usize ,
131133 /// Index of currently scanned batch from buffered data
132134 pub buffered_batch_idx : Option < usize > ,
133135 /// Indices that found a match for the given join filter
@@ -144,6 +146,7 @@ impl StreamedBatch {
144146 idx : 0 ,
145147 join_arrays,
146148 output_indices : vec ! [ ] ,
149+ num_output_rows : 0 ,
147150 buffered_batch_idx : None ,
148151 join_filter_matched_idxs : HashSet :: new ( ) ,
149152 }
@@ -155,17 +158,15 @@ impl StreamedBatch {
155158 idx : 0 ,
156159 join_arrays : vec ! [ ] ,
157160 output_indices : vec ! [ ] ,
161+ num_output_rows : 0 ,
158162 buffered_batch_idx : None ,
159163 join_filter_matched_idxs : HashSet :: new ( ) ,
160164 }
161165 }
162166
163167 /// Number of unfrozen output pairs in this streamed batch, served from the `num_output_rows` counter field instead of re-summing the chunk lengths in `output_indices`
164168 fn num_output_rows ( & self ) -> usize {
165- self . output_indices
166- . iter ( )
167- . map ( |chunk| chunk. streamed_indices . len ( ) )
168- . sum ( )
169+ self . num_output_rows
169170 }
170171
171172 /// Appends new pair consisting of current streamed index and `buffered_idx`
@@ -175,20 +176,20 @@ impl StreamedBatch {
175176 buffered_batch_idx : Option < usize > ,
176177 buffered_idx : Option < usize > ,
177178 batch_size : usize ,
178- num_unfrozen_pairs : usize ,
179179 ) {
180180 // If no current chunk exists or current chunk is not for current buffered batch,
181181 // create a new chunk
182182 if self . output_indices . is_empty ( ) || self . buffered_batch_idx != buffered_batch_idx
183183 {
184184 // Compute capacity only when creating a new chunk (infrequent operation).
185185 // The capacity is the remaining space to reach batch_size.
186- // This should always be >= 1 since we only call this when num_unfrozen_pairs < batch_size.
186+ // This should always be >= 1 since we only call this when num_output_rows < batch_size.
187187 debug_assert ! (
188- batch_size > num_unfrozen_pairs,
189- "batch_size ({batch_size}) must be > num_unfrozen_pairs ({num_unfrozen_pairs})"
188+ batch_size > self . num_output_rows,
189+ "batch_size ({batch_size}) must be > num_output_rows ({})" ,
190+ self . num_output_rows
190191 ) ;
191- let capacity = batch_size - num_unfrozen_pairs ;
192+ let capacity = batch_size - self . num_output_rows ;
192193 self . output_indices . push ( StreamedJoinedChunk {
193194 buffered_batch_idx,
194195 streamed_indices : UInt64Builder :: with_capacity ( capacity) ,
@@ -205,6 +206,7 @@ impl StreamedBatch {
205206 } else {
206207 current_chunk. buffered_indices . append_null ( ) ;
207208 }
209+ self . num_output_rows += 1 ;
208210 }
209211}
210212
@@ -1134,13 +1136,10 @@ impl SortMergeJoinStream {
11341136 let scanning_idx = self . buffered_data . scanning_idx ( ) ;
11351137 if join_streamed {
11361138 // Join streamed row and buffered row
1137- // Pass batch_size and num_unfrozen_pairs to compute capacity only when
1138- // creating a new chunk (when buffered_batch_idx changes), not on every iteration.
11391139 self . streamed_batch . append_output_pair (
11401140 Some ( self . buffered_data . scanning_batch_idx ) ,
11411141 Some ( scanning_idx) ,
11421142 self . batch_size ,
1143- self . num_unfrozen_pairs ( ) ,
11441143 ) ;
11451144 } else {
11461145 // Join nulls and buffered row for FULL join
@@ -1166,13 +1165,10 @@ impl SortMergeJoinStream {
11661165 // For Mark join we store a dummy id to indicate the row has a match
11671166 let scanning_idx = mark_row_as_match. then_some ( 0 ) ;
11681167
1169- // Pass batch_size=1 and num_unfrozen_pairs=0 to get capacity of 1,
1170- // since we only append a single null-joined pair here (not in a loop).
11711168 self . streamed_batch . append_output_pair (
11721169 scanning_batch_idx,
11731170 scanning_idx,
1174- 1 ,
1175- 0 ,
1171+ self . batch_size ,
11761172 ) ;
11771173 self . buffered_data . scanning_finish ( ) ;
11781174 self . streamed_joined = true ;
@@ -1469,6 +1465,7 @@ impl SortMergeJoinStream {
14691465 }
14701466
14711467 self . streamed_batch . output_indices . clear ( ) ;
1468+ self . streamed_batch . num_output_rows = 0 ;
14721469
14731470 Ok ( ( ) )
14741471 }
0 commit comments