Skip to content

Commit 7e20eb7

Browse files
authored
[branch-52] perf: Cache num_output_rows in sort merge join to avoid O(n) recount (#20478) (#20936)
1 parent e5547e2 commit 7e20eb7

1 file changed

Lines changed: 13 additions & 16 deletions

File tree

  • datafusion/physical-plan/src/joins/sort_merge_join

datafusion/physical-plan/src/joins/sort_merge_join/stream.rs

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ pub(super) struct StreamedBatch {
128128
pub join_arrays: Vec<ArrayRef>,
129129
/// Chunks of indices from buffered side (may be nulls) joined to streamed
130130
pub output_indices: Vec<StreamedJoinedChunk>,
131+
/// Total number of output rows across all chunks in `output_indices`
132+
pub num_output_rows: usize,
131133
/// Index of currently scanned batch from buffered data
132134
pub buffered_batch_idx: Option<usize>,
133135
/// Indices that found a match for the given join filter
@@ -144,6 +146,7 @@ impl StreamedBatch {
144146
idx: 0,
145147
join_arrays,
146148
output_indices: vec![],
149+
num_output_rows: 0,
147150
buffered_batch_idx: None,
148151
join_filter_matched_idxs: HashSet::new(),
149152
}
@@ -155,17 +158,15 @@ impl StreamedBatch {
155158
idx: 0,
156159
join_arrays: vec![],
157160
output_indices: vec![],
161+
num_output_rows: 0,
158162
buffered_batch_idx: None,
159163
join_filter_matched_idxs: HashSet::new(),
160164
}
161165
}
162166

163167
/// Number of unfrozen output pairs in this streamed batch
164168
fn num_output_rows(&self) -> usize {
165-
self.output_indices
166-
.iter()
167-
.map(|chunk| chunk.streamed_indices.len())
168-
.sum()
169+
self.num_output_rows
169170
}
170171

171172
/// Appends new pair consisting of current streamed index and `buffered_idx`
@@ -175,20 +176,20 @@ impl StreamedBatch {
175176
buffered_batch_idx: Option<usize>,
176177
buffered_idx: Option<usize>,
177178
batch_size: usize,
178-
num_unfrozen_pairs: usize,
179179
) {
180180
// If no current chunk exists or current chunk is not for current buffered batch,
181181
// create a new chunk
182182
if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx
183183
{
184184
// Compute capacity only when creating a new chunk (infrequent operation).
185185
// The capacity is the remaining space to reach batch_size.
186-
// This should always be >= 1 since we only call this when num_unfrozen_pairs < batch_size.
186+
// This should always be >= 1 since we only call this when num_output_rows < batch_size.
187187
debug_assert!(
188-
batch_size > num_unfrozen_pairs,
189-
"batch_size ({batch_size}) must be > num_unfrozen_pairs ({num_unfrozen_pairs})"
188+
batch_size > self.num_output_rows,
189+
"batch_size ({batch_size}) must be > num_output_rows ({})",
190+
self.num_output_rows
190191
);
191-
let capacity = batch_size - num_unfrozen_pairs;
192+
let capacity = batch_size - self.num_output_rows;
192193
self.output_indices.push(StreamedJoinedChunk {
193194
buffered_batch_idx,
194195
streamed_indices: UInt64Builder::with_capacity(capacity),
@@ -205,6 +206,7 @@ impl StreamedBatch {
205206
} else {
206207
current_chunk.buffered_indices.append_null();
207208
}
209+
self.num_output_rows += 1;
208210
}
209211
}
210212

@@ -1134,13 +1136,10 @@ impl SortMergeJoinStream {
11341136
let scanning_idx = self.buffered_data.scanning_idx();
11351137
if join_streamed {
11361138
// Join streamed row and buffered row
1137-
// Pass batch_size and num_unfrozen_pairs to compute capacity only when
1138-
// creating a new chunk (when buffered_batch_idx changes), not on every iteration.
11391139
self.streamed_batch.append_output_pair(
11401140
Some(self.buffered_data.scanning_batch_idx),
11411141
Some(scanning_idx),
11421142
self.batch_size,
1143-
self.num_unfrozen_pairs(),
11441143
);
11451144
} else {
11461145
// Join nulls and buffered row for FULL join
@@ -1166,13 +1165,10 @@ impl SortMergeJoinStream {
11661165
// For Mark join we store a dummy id to indicate the row has a match
11671166
let scanning_idx = mark_row_as_match.then_some(0);
11681167

1169-
// Pass batch_size=1 and num_unfrozen_pairs=0 to get capacity of 1,
1170-
// since we only append a single null-joined pair here (not in a loop).
11711168
self.streamed_batch.append_output_pair(
11721169
scanning_batch_idx,
11731170
scanning_idx,
1174-
1,
1175-
0,
1171+
self.batch_size,
11761172
);
11771173
self.buffered_data.scanning_finish();
11781174
self.streamed_joined = true;
@@ -1469,6 +1465,7 @@ impl SortMergeJoinStream {
14691465
}
14701466

14711467
self.streamed_batch.output_indices.clear();
1468+
self.streamed_batch.num_output_rows = 0;
14721469

14731470
Ok(())
14741471
}

0 commit comments

Comments
 (0)