|
| 1 | +<!-- |
| 2 | + Copyright 2026 Google LLC |
| 3 | +
|
| 4 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + you may not use this file except in compliance with the License. |
| 6 | + You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | + Unless required by applicable law or agreed to in writing, software |
| 11 | + distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + See the License for the specific language governing permissions and |
| 14 | + limitations under the License. |
| 15 | + --> |
| 16 | + |
| 17 | +# Batch Size |
| 18 | + |
| 19 | +This document explains the different concepts of "batch size" within MaxText and how to configure them to tune performance and manage memory. |
| 20 | + |
| 21 | +## Per-Device Batch Size |
| 22 | + |
| 23 | +`per_device_batch_size` is the number of training examples processed by a single device in one forward and backward pass. This value impacts the memory usage on each device and is a configuration parameter in `configs/base.yml` |
| 24 | + |
| 25 | +## Global Batch Size |
| 26 | + |
| 27 | +`global_batch_to_train` is the total number of training examples processed before the optimizer performs a single weight update. It is the effective batch size for training, calculated as: |
| 28 | + |
| 29 | +`global_batch_to_train = per_device_batch_size x number_of_devices x gradient_accumulation_steps` |
| 30 | + |
| 31 | +You can set `per_device_batch_size` and `gradient_accumulation_steps` in `configs/base.yml`. |
| 32 | + |
| 33 | +`global_batch_to_load` is the total number of examples the data input pipeline loads from storage at once. It can be larger than the training batch size to optimize I/O performance, and is calculated as: |
| 34 | + |
| 35 | +`global_batch_to_load` = `global_batch_size_to_train_on x expansion_factor_real_data` |
| 36 | + |
| 37 | +When `expansion_factor_real_data > 1`, only a subset of hosts read data from the source (e.g., a GCS bucket). These "loading hosts" read more data than they need for their own devices and distribute the surplus to other "non-loading" hosts. This reduces the number of concurrent connections to the data source, which can significantly improve I/O throughput. When set to between 0 and 1, it's for grain pipeline to use a smaller chip count to read checkpoint from a larger chip count job. Details in https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/data_input_grain.md#using-grain. |
| 38 | + |
| 39 | +## Gradient Accumulation Steps |
| 40 | + |
| 41 | +`gradient_accumulation_steps` defines how many forward/backward passes are performed before the optimizer updates the model weights. The gradients from each pass are accumulated (summed). It is discussed in more detail [here](https://maxtext.readthedocs.io/en/latest/reference/core_concepts/tiling.html#gradient-accumulation). |
| 42 | + |
| 43 | +For example, if `gradient_accumulation_steps` is set to `4`, the model will execute four forward and backward passes, sum the gradients, and then apply a single optimizer step. This achieves the same effective global batch size as quadrupling the `per_device_batch_size` with significantly less memory, but can potentially lead to lower MFU. |
| 44 | + |
| 45 | +## Pipeline Microbatches |
| 46 | + |
| 47 | +When pipeline parallelism is enabled, the global batch is split into smaller chunks called **microbatches**. These are fed into the pipeline sequentially, allowing different stages of the model to work on different microbatches simultaneously. |
| 48 | + |
| 49 | +The `num_pipeline_microbatches` parameter in `configs/base.yml` configures how many of these smaller chunks the global batch is divided into. It must be a multiple of the total number of pipeline stages (`ici_pipeline_parallelism` x `dcn_pipeline_parallelism`). |
| 50 | + |
| 51 | +The choice of `num_pipeline_microbatches` is a trade-off between reducing pipeline idle time and the computational efficiency within each stage. More microbatches reduces the "Pipeline Bubble" but leads to smaller matrix multiplications within each stage. Very small operations may not fully saturate the compute units of the hardware, potentially lowering arithmetic intensity. |
| 52 | + |
| 53 | +## Batch Size Ramp-up |
| 54 | + |
| 55 | +MaxText supports gradually increasing the batch size during the initial phase of training to improve stability, a technique also used in frameworks like [NVIDIA's NeMo Megatron](https://docs.nvidia.com/nemo-framework/user-guide/24.09/nemotoolkit/nlp/nemo_megatron/rampup_batch_size.html). This can be configured in `configs/base.yml`: |
| 56 | + |
| 57 | +- Setting `enable_rampup_batch_size=True` activates the ramp-up process. |
| 58 | +- `per_device_batch_size_start`: The minimum batch size to start training on. |
| 59 | +- `per_device_batch_size`: The target batch size to stabilize on at the end of the ramp-up process. |
| 60 | +- `per_device_batch_size_increment`: How much batch size increases for each ramp-up stage. |
| 61 | +- `global_rampup_samples`: The total number of samples to process across all ramp-up stages. |
| 62 | + |
| 63 | +The ramp-up is based on the number of samples processed, not the number of training steps. Each stage processes an equal number of samples before batch size is increased. |
| 64 | + |
| 65 | +The number of stages is determined by: |
| 66 | + |
| 67 | +`num_increments = (per_device_batch_size - per_device_batch_size_start) / per_device_batch_size_increment` |
| 68 | + |
| 69 | +The total number of ramp-up samples (`global_rampup_samples`) is then distributed equally across these stages. The number of samples processed in each stage is determined by: |
| 70 | + |
| 71 | +`samples_per_increment = global_rampup_samples / num_increments` |
| 72 | + |
| 73 | +During training, the model processes `samples_per_increment` samples at the current batch size. Once this threshold is reached, the batch size is increased by `per_device_batch_size_increment` until the target `per_device_batch_size` is reached. This entire process is managed by the `RampupBatchManager` class. |
| 74 | + |
| 75 | +## Reinforcement Learning (RL) Batch Size |
| 76 | + |
| 77 | +The batch size parameters for RL training are defined in `configs/post_train/rl.yml`: |
| 78 | + |
| 79 | +- `batch_size` refers to the number of unique prompts loaded from the dataset in a single batch. For instance, `batch_size=1` means one prompt is processed at a time by the data loader. |
| 80 | + |
| 81 | +- `num_generations` is the number of times the policy generates multiple responses for a given prompt within a single training step. |
| 82 | + |
| 83 | +- The effective training batch is the total number of prompt-response pairs used in a training step, calculated as `batch_size x num_generations`. It is determined by the number of responses generated for each prompt, which is configured by `num_generations`. |
| 84 | + |
| 85 | +- `micro_batch_size` is used to split the batch of prompt-response pairs into smaller chunks for memory management. This enables overlapping the rollout phase (generating responses) of one micro-batch with the training phase (updating model weights) of the previous micro-batch, which can improve hardware utilization. A value of `-1` means no micro-batching is enabled. |
0 commit comments