diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index b4af6e2c09a5c..c36274917d2f3 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -22,6 +22,7 @@ use std::fmt::{Debug, Formatter};
 use std::pin::Pin;
 use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::task::{Context, Poll};
 use std::vec;
 
@@ -71,7 +72,7 @@ use crate::sort_pushdown::SortOrderPushdownResult;
 use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
 use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
 use futures::stream::Stream;
-use futures::{FutureExt, StreamExt, TryStreamExt, ready};
+use futures::{FutureExt, StreamExt, TryStreamExt};
 use log::trace;
 use parking_lot::Mutex;
 
@@ -143,11 +144,183 @@ type MaybeBatch = Option<Result<RepartitionBatch>>;
 type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
 type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
 
-/// Output channel with its associated memory reservation and spill writer
+/// One input task's collection of output channels (its send-side view of
+/// every output partition). Owns the per-call helpers for coalescing,
+/// finalizing, and sending.
+///
+/// This complements [`PartitionChannels`] (the per-output-partition,
+/// authoritative struct that owns `rx`, `spill_readers`, and the underlying
+/// `Mutex` / `AtomicUsize`). Each [`OutputChannel`]
+/// inside `inner` holds cloned `Arc`s pointing back at those shared
+/// resources.
+struct OutputChannels {
+    inner: HashMap<usize, OutputChannel>,
+    metrics: RepartitionMetrics,
+}
+
+impl OutputChannels {
+    fn new(inner: HashMap<usize, OutputChannel>, metrics: RepartitionMetrics) -> Self {
+        Self { inner, metrics }
+    }
+
+    fn is_empty(&self) -> bool {
+        self.inner.is_empty()
+    }
+
+    fn metrics(&self) -> &RepartitionMetrics {
+        &self.metrics
+    }
+
+    /// Push `batch` for `partition` through its shared coalescer (if any)
+    /// and ship any newly completed batches through the channel.
+    async fn coalesce_and_send(
+        &mut self,
+        partition: usize,
+        batch: RecordBatch,
+    ) -> Result<()> {
+        let Some(channel) = self.inner.get(&partition) else {
+            return Ok(());
+        };
+        let to_send = match &channel.shared_coalescer {
+            Some(shared) => shared.push_and_drain(batch)?,
+            None => vec![batch],
+        };
+        for batch in to_send {
+            self.send_to_channel(partition, batch).await;
+        }
+        Ok(())
+    }
+
+    /// For each output partition this task still has, decrement the shared
+    /// active-senders counter. The last task to do so calls
+    /// [`SharedCoalescer::finalize`] and ships the residual.
+    async fn finalize(&mut self) -> Result<()> {
+        let partitions: Vec<usize> = self.inner.keys().copied().collect();
+        for partition in partitions {
+            let Some(channel) = self.inner.get(&partition) else {
+                continue;
+            };
+            let Some(shared) = channel.shared_coalescer.clone() else {
+                continue;
+            };
+            for batch in shared.finalize()? {
+                self.send_to_channel(partition, batch).await;
+            }
+        }
+        Ok(())
+    }
+
+    /// Send a single batch through the channel for `partition`, applying
+    /// the memory reservation / spill-writer fallback. Removes the channel
+    /// from `self.inner` if the receiver has hung up.
+    async fn send_to_channel(&mut self, partition: usize, batch: RecordBatch) {
+        let size = batch.get_array_memory_size();
+        let timer = self.metrics.send_time[partition].timer();
+
+        // Decide the payload outside of any await: never hold a MutexGuard
+        // across an await point.
+        let (payload, is_memory_batch) = {
+            let Some(channel) = self.inner.get(&partition) else {
+                timer.done();
+                return;
+            };
+            match channel.reservation.try_grow(size) {
+                Ok(_) => (Ok(RepartitionBatch::Memory(batch)), true),
+                Err(_) => match channel.spill_writer.push_batch(&batch) {
+                    Ok(()) => (Ok(RepartitionBatch::Spilled), false),
+                    Err(err) => (Err(err), false),
+                },
+            }
+        };
+
+        let Some(channel) = self.inner.get(&partition) else {
+            timer.done();
+            return;
+        };
+        let send_err = channel.sender.send(Some(payload)).await.is_err();
+        if send_err {
+            if is_memory_batch && let Some(channel) = self.inner.get(&partition) {
+                channel.reservation.shrink(size);
+            }
+            self.inner.remove(&partition);
+        }
+        timer.done();
+    }
+}
+
+/// A producer-side coalescer shared across all input tasks targeting a
+/// single output partition.
+///
+/// Bundles the [`LimitedBatchCoalescer`] (behind a [`Mutex`]) with the
+/// active-sender counter that tracks how many input tasks may still push
+/// into it. The last task to call [`Self::finalize`] is the one that
+/// finalizes the coalescer and ships the residual batch.
+///
+/// Cheap to [`Clone`]: both fields are [`Arc`]s.
+#[derive(Clone)]
+struct SharedCoalescer {
+    inner: Arc<Mutex<LimitedBatchCoalescer>>,
+    active_senders: Arc<AtomicUsize>,
+}
+
+impl SharedCoalescer {
+    fn new(
+        schema: SchemaRef,
+        target_batch_size: usize,
+        fetch: Option<usize>,
+        num_senders: usize,
+    ) -> Self {
+        Self {
+            inner: Arc::new(Mutex::new(LimitedBatchCoalescer::new(
+                schema,
+                target_batch_size,
+                fetch,
+            ))),
+            active_senders: Arc::new(AtomicUsize::new(num_senders)),
+        }
+    }
+
+    /// Push `batch` into the coalescer and drain any newly completed
+    /// batches. The mutex is held only briefly.
+    fn push_and_drain(&self, batch: RecordBatch) -> Result<Vec<RecordBatch>> {
+        let mut acc = Vec::new();
+        let mut c = self.inner.lock();
+        c.push_batch(batch)?;
+        while let Some(b) = c.next_completed_batch() {
+            acc.push(b);
+        }
+        Ok(acc)
+    }
+
+    /// Decrement the active-senders counter. If this caller was the last
+    /// sender, finalize the coalescer and return its residual batches; if
+    /// other senders are still active, return an empty `Vec`.
+    fn finalize(&self) -> Result<Vec<RecordBatch>> {
+        let was_last = self.active_senders.fetch_sub(1, Ordering::AcqRel) == 1;
+        if !was_last {
+            return Ok(vec![]);
+        }
+        let mut acc = Vec::new();
+        let mut c = self.inner.lock();
+        c.finish()?;
+        while let Some(b) = c.next_completed_batch() {
+            acc.push(b);
+        }
+        Ok(acc)
+    }
+}
+
+/// Output channel with its associated memory reservation and spill writer.
+///
+/// `shared_coalescer` is `None` for preserve-order mode, where downstream
+/// [`StreamingMergeBuilder`] performs the batching; otherwise it's a
+/// [`SharedCoalescer`] cloned from the per-partition one held by
+/// [`PartitionChannels`].
 struct OutputChannel {
     sender: DistributionSender<MaybeBatch>,
     reservation: SharedMemoryReservation,
     spill_writer: SpillPoolWriter,
+    shared_coalescer: Option<SharedCoalescer>,
 }
 
 /// Channels and resources for a single output partition.
@@ -178,6 +351,10 @@ struct PartitionChannels {
     rx: InputPartitionsToCurrentPartitionReceiver,
     /// Memory reservation for this output partition
     reservation: SharedMemoryReservation,
+    /// Shared coalescer used by all input tasks targeting this output
+    /// partition. `None` in preserve-order mode (downstream
+    /// `StreamingMergeBuilder` handles batching).
+    shared_coalescer: Option<SharedCoalescer>,
     /// Spill writers for writing spilled data.
     /// SpillPoolWriter is Clone, so multiple writers can share state in non-preserve-order mode.
     spill_writers: Vec<SpillPoolWriter>,
@@ -272,6 +449,7 @@ impl RepartitionExecState {
         name: &str,
         context: &Arc<TaskContext>,
         spill_manager: SpillManager,
+        fetch: Option<usize>,
     ) -> Result<&mut ConsumingInputStreamsState> {
         let streams_and_metrics = match self {
             RepartitionExecState::NotInitialized => {
@@ -347,6 +525,19 @@ impl RepartitionExecState {
                 .map(|_| spill_pool::channel(max_file_size, Arc::clone(&spill_manager)))
                 .unzip();
 
+            // Coalesce on the producer side, before the channel's gate, so
+            // the consumer never sees the per-input-task small batches.
+            // Skip in preserve-order mode: each input has its own dedicated
+            // channel and `StreamingMergeBuilder` handles batching.
+            let shared_coalescer = (!preserve_order).then(|| {
+                SharedCoalescer::new(
+                    input.schema(),
+                    context.session_config().batch_size(),
+                    fetch,
+                    num_input_partitions,
+                )
+            });
+
             channels.insert(
                 partition,
                 PartitionChannels {
@@ -355,6 +546,7 @@ impl RepartitionExecState {
                     reservation,
                     spill_readers,
                     spill_writers,
+                    shared_coalescer,
                 },
             );
         }
@@ -377,6 +569,7 @@ impl RepartitionExecState {
                             reservation: Arc::clone(&channels.reservation),
                             spill_writer: channels.spill_writers[spill_writer_idx]
                                 .clone(),
+                            shared_coalescer: channels.shared_coalescer.clone(),
                         },
                     )
                 })
@@ -390,9 +583,8 @@ impl RepartitionExecState {
 
             let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                 stream,
-                txs,
+                OutputChannels::new(txs, metrics),
                 partitioning.clone(),
-                metrics,
                 // preserve_order depends on partition index to start from 0
                 if preserve_order { 0 } else { i },
                 num_input_partitions,
@@ -1055,6 +1247,7 @@ impl ExecutionPlan for RepartitionExec {
         let name = self.name().to_owned();
         let schema = self.schema();
         let schema_captured = Arc::clone(&schema);
+        let fetch = self.fetch();
 
         let spill_manager = SpillManager::new(
             Arc::clone(&context.runtime_env()),
@@ -1090,6 +1283,7 @@ impl ExecutionPlan for RepartitionExec {
                 &name,
                 &context,
                 spill_manager.clone(),
+                fetch,
             )?;
 
             // now return stream for the specified *output* partition which will
@@ -1132,7 +1326,6 @@ impl ExecutionPlan for RepartitionExec {
                             spill_stream,
                             1, // Each receiver handles one input partition
                             BaselineMetrics::new(&metrics, partition),
-                            None, // subsequent merge sort already does batching https://github.com/apache/datafusion/blob/e4dcf0c85611ad0bd291f03a8e03fe56d773eb16/datafusion/physical-plan/src/sorts/merge.rs#L286
                         )) as SendableRecordBatchStream
                     })
                     .collect::<Vec<_>>();
@@ -1171,7 +1364,6 @@ impl ExecutionPlan for RepartitionExec {
                     spill_stream,
                     num_input_partitions,
                     BaselineMetrics::new(&metrics, partition),
-                    Some(context.session_config().batch_size()),
                 )) as SendableRecordBatchStream)
             }
         })
@@ -1425,24 +1617,24 @@ impl RepartitionExec {
     /// `output_channels` holds the output sending channels for each output partition
     async fn pull_from_input(
         mut stream: SendableRecordBatchStream,
-        mut output_channels: HashMap<usize, OutputChannel>,
+        mut output_channels: OutputChannels,
         partitioning: Partitioning,
-        metrics: RepartitionMetrics,
         input_partition: usize,
         num_input_partitions: usize,
     ) -> Result<()> {
+        let repartition_time = output_channels.metrics().repartition_time.clone();
         let mut partitioner = match &partitioning {
             Partitioning::Hash(exprs, num_partitions) => {
                 BatchPartitioner::new_hash_partitioner(
                     exprs.clone(),
                     *num_partitions,
-                    metrics.repartition_time.clone(),
+                    repartition_time,
                 )?
             }
             Partitioning::RoundRobinBatch(num_partitions) => {
                 BatchPartitioner::new_round_robin_partitioner(
                     *num_partitions,
-                    metrics.repartition_time.clone(),
+                    repartition_time,
                     input_partition,
                     num_input_partitions,
                 )
@@ -1456,7 +1648,7 @@ impl RepartitionExec {
         let mut batches_until_yield = partitioner.num_partitions();
         while !output_channels.is_empty() {
             // fetch the next batch
-            let timer = metrics.fetch_time.timer();
+            let timer = output_channels.metrics().fetch_time.timer();
             let result = stream.next().await;
             timer.done();
 
@@ -1473,36 +1665,7 @@ impl RepartitionExec {
             for res in partitioner.partition_iter(batch)? {
                 let (partition, batch) = res?;
-                let size = batch.get_array_memory_size();
-
-                let timer = metrics.send_time[partition].timer();
-                // if there is still a receiver, send to it
-                if let Some(channel) = output_channels.get_mut(&partition) {
-                    let (batch_to_send, is_memory_batch) =
-                        match channel.reservation.try_grow(size) {
-                            Ok(_) => {
-                                // Memory available - send in-memory batch
-                                (RepartitionBatch::Memory(batch), true)
-                            }
-                            Err(_) => {
-                                // We're memory limited - spill to SpillPool
-                                // SpillPool handles file handle reuse and rotation
-                                channel.spill_writer.push_batch(&batch)?;
-                                // Send marker indicating batch was spilled
-                                (RepartitionBatch::Spilled, false)
-                            }
-                        };
-
-                    if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() {
-                        // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
-                        // Only shrink memory if it was a memory batch
-                        if is_memory_batch {
-                            channel.reservation.shrink(size);
-                        }
-                        output_channels.remove(&partition);
-                    }
-                }
-                timer.done();
+                output_channels.coalesce_and_send(partition, batch).await?;
             }
 
             // If the input stream is endless, we may spin forever and
@@ -1529,6 +1692,12 @@ impl RepartitionExec {
             }
         }
 
+        // End of input for this task. For each output partition we still
+        // have a channel to, decrement the active-senders counter; whoever
+        // sees the count drop to zero is the last input task and must
+        // finalize the shared coalescer and ship its residual.
+        output_channels.finalize().await?;
+
         // Spill writers will auto-finalize when dropped
         // No need for explicit flush
         Ok(())
@@ -1660,13 +1829,9 @@ struct PerPartitionStream {
 
     /// Execution metrics
     baseline_metrics: BaselineMetrics,
-
-    /// None for sort preserving variant (merge sort already does coalescing)
-    batch_coalescer: Option<LimitedBatchCoalescer>,
 }
 
 impl PerPartitionStream {
-    #[expect(clippy::too_many_arguments)]
     fn new(
         schema: SchemaRef,
         receiver: DistributionReceiver<MaybeBatch>,
@@ -1675,10 +1840,7 @@ impl PerPartitionStream {
         spill_stream: SendableRecordBatchStream,
         num_input_partitions: usize,
         baseline_metrics: BaselineMetrics,
-        batch_size: Option<usize>,
     ) -> Self {
-        let batch_coalescer =
-            batch_size.map(|s| LimitedBatchCoalescer::new(Arc::clone(&schema), s, None));
         Self {
             schema,
             receiver,
@@ -1688,7 +1850,6 @@ impl PerPartitionStream {
             state: StreamState::ReadingMemory,
             remaining_partitions: num_input_partitions,
             baseline_metrics,
-            batch_coalescer,
         }
     }
@@ -1770,43 +1931,6 @@ impl PerPartitionStream {
             }
         }
     }
-
-    fn poll_next_and_coalesce(
-        self: &mut Pin<&mut Self>,
-        cx: &mut Context<'_>,
-        coalescer: &mut LimitedBatchCoalescer,
-    ) -> Poll<Option<Result<RecordBatch>>> {
-        let cloned_time = self.baseline_metrics.elapsed_compute().clone();
-        let mut completed = false;
-
-        loop {
-            if let Some(batch) = coalescer.next_completed_batch() {
-                return Poll::Ready(Some(Ok(batch)));
-            }
-            if completed {
-                return Poll::Ready(None);
-            }
-
-            match ready!(self.poll_next_inner(cx)) {
-                Some(Ok(batch)) => {
-                    let _timer = cloned_time.timer();
-                    if let Err(err) = coalescer.push_batch(batch) {
-                        return Poll::Ready(Some(Err(err)));
-                    }
-                }
-                Some(err) => {
-                    return Poll::Ready(Some(err));
-                }
-                None => {
-                    completed = true;
-                    let _timer = cloned_time.timer();
-                    if let Err(err) = coalescer.finish() {
-                        return Poll::Ready(Some(Err(err)));
-                    }
-                }
-            }
-        }
-    }
 }
 
 impl Stream for PerPartitionStream {
@@ -1816,13 +1940,7 @@ impl Stream for PerPartitionStream {
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        let poll;
-        if let Some(mut coalescer) = self.batch_coalescer.take() {
-            poll = self.poll_next_and_coalesce(cx, &mut coalescer);
-            self.batch_coalescer = Some(coalescer);
-        } else {
-            poll = self.poll_next_inner(cx);
-        }
+        let poll = self.poll_next_inner(cx);
         self.baseline_metrics.record_poll(poll)
     }
 }
@@ -2526,13 +2644,17 @@ mod tests {
         let input_partitions = vec![partition];
         let partitioning = Partitioning::RoundRobinBatch(4);
 
-        // Set up context with moderate memory limit to force partial spilling
-        // 2KB should allow some batches in memory but force others to spill
+        // With `batch_size = 1024` and a single UInt32 column, each
+        // coalesced residual is ~4 KiB. An 8 KiB pool fits one and forces
+        // the rest to spill.
         let runtime = RuntimeEnvBuilder::default()
-            .with_memory_limit(2 * 1024, 1.0)
+            .with_memory_limit(8 * 1024, 1.0)
            .build_arc()?;
-        let task_ctx = TaskContext::default().with_runtime(runtime);
+        let session_config = SessionConfig::new().with_batch_size(1024);
+        let task_ctx = TaskContext::default()
+            .with_runtime(runtime)
+            .with_session_config(session_config);
        let task_ctx = Arc::new(task_ctx);
 
         // create physical plan
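
The "never hold a MutexGuard across an await point" comment in
`send_to_channel` is load-bearing: the payload is decided inside a block so
the lock (and the borrow of `self.inner`) ends before the `.await`. Below is
a minimal, self-contained sketch of that guard-scoping idiom; `forward` and
`send` are hypothetical stand-ins, not part of this patch.

use std::sync::Mutex;

async fn send(_payload: u64) {
    // stand-in for an async channel send
}

async fn forward(shared: &Mutex<Vec<u64>>) {
    // Decide the payload while the lock is held...
    let payload = {
        let mut guard = shared.lock().unwrap();
        guard.pop()
    }; // ...the guard is dropped here, before any await
    if let Some(p) = payload {
        send(p).await;
    }
}

fn assert_send<T: Send>(_: T) {}

fn main() {
    let shared = Mutex::new(vec![1, 2, 3]);
    // Compiles only because the guard does not live across the `.await`;
    // holding it there would make this future !Send.
    assert_send(forward(&shared));
}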
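
The single-finalizer guarantee in `SharedCoalescer::finalize` rests on
`fetch_sub` returning the counter's previous value, so exactly one caller
observes `1` regardless of how the input tasks interleave. A runnable sketch
of the same pattern, with a hypothetical `Coalescer` standing in for
`LimitedBatchCoalescer`:

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;

// Hypothetical stand-in for LimitedBatchCoalescer: just buffers values.
struct Coalescer {
    buf: Vec<u32>,
}

#[derive(Clone)]
struct Shared {
    inner: Arc<Mutex<Coalescer>>,
    active_senders: Arc<AtomicUsize>,
}

impl Shared {
    /// Returns `Some(residual)` only for the last active sender.
    fn finalize(&self) -> Option<Vec<u32>> {
        // fetch_sub returns the previous value: only the caller that
        // brings the count from 1 to 0 drains the residual.
        if self.active_senders.fetch_sub(1, Ordering::AcqRel) == 1 {
            Some(std::mem::take(&mut self.inner.lock().unwrap().buf))
        } else {
            None
        }
    }
}

fn main() {
    let shared = Shared {
        inner: Arc::new(Mutex::new(Coalescer { buf: vec![1, 2, 3] })),
        active_senders: Arc::new(AtomicUsize::new(4)),
    };
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let s = shared.clone();
            thread::spawn(move || s.finalize().is_some())
        })
        .collect();
    let finalizers = handles
        .into_iter()
        .map(|h| h.join().unwrap())
        .filter(|&was_last| was_last)
        .count();
    assert_eq!(finalizers, 1); // exactly one thread drained the residual
}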
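
The revised test limits follow from simple sizing arithmetic: a 1024-row,
single-column UInt32 batch carries 1024 * 4 = 4096 bytes of array data, so
one coalesced batch fits under the 8 KiB limit while a second `try_grow`
must fail and spill. The figure can be checked directly; a standalone
sketch assuming only the `arrow` crate:

use std::sync::Arc;
use arrow::array::{ArrayRef, UInt32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), arrow::error::ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "c0",
        DataType::UInt32,
        false,
    )]));
    let array = UInt32Array::from_iter_values(0u32..1024);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(array) as ArrayRef])?;
    // ~4096 bytes of buffer data, plus small per-array bookkeeping
    assert!(batch.get_array_memory_size() >= 4096);
    println!("batch size: {} bytes", batch.get_array_memory_size());
    Ok(())
}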