Skip to content

Commit 0f85477

Browse files
Merge pull request #3040 from AI-Hypercomputer:jimmytsai/fix-ga-in-sft-trainer
PiperOrigin-RevId: 865712234
2 parents 88805db + 9e9916f commit 0f85477

7 files changed

Lines changed: 33 additions & 5 deletions

File tree

src/MaxText/configs/sft-vision-chartqa.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
base_config: "base.yml"
1616

1717
use_sft: True
18+
use_tunix_gradient_accumulation: True
1819
use_multimodal: True
1920
# For vision, the prompt contains image, we only train on completion tokens
2021
sft_train_on_completion_only: True

src/MaxText/configs/sft-vision-slidevqa.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
base_config: "base.yml"
1616

1717
use_sft: True
18+
use_tunix_gradient_accumulation: True
1819
use_multimodal: True
1920
# For vision, the prompt contains image, we only train on completion tokens
2021
sft_train_on_completion_only: True

src/MaxText/configs/sft.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
base_config: "base.yml"
1616

1717
use_sft: True
18+
use_tunix_gradient_accumulation: True
1819
# sft_train_on_completion_only=False trains on both prompt and completion tokens; trains only on completion tokens otherwise
1920
sft_train_on_completion_only: True
2021
packing: True

src/MaxText/configs/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,10 @@ class Optimizer(BaseModel):
10561056
gradient_accumulation_steps: PositiveInt = Field(
10571057
1, description="Number of steps to accumulate gradients before updating."
10581058
)
1059+
use_tunix_gradient_accumulation: bool = Field(
1060+
False,
1061+
description="Whether to use the Tunix implementation for gradient accumulation.",
1062+
)
10591063
gradient_clipping_threshold: NonNegativeFloat = Field(
10601064
1.0, description="The threshold for gradient clipping. 0 disables clipping."
10611065
)

src/MaxText/input_pipeline/_hf_data_processing.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,13 @@ def vision_sft_preprocessing_pipeline(
5454
"""pipeline for multimodal SFT with HF dataset"""
5555

5656
assert len(text_columns) == 2, f"Need two text_columns for query and response, received {text_columns=}"
57-
batch_size = global_batch_size // jax.process_count()
57+
# Tunix GA requires per-micro-batch slicing at the data level,
58+
# whereas Native GA processes the full batch and splits it internally.
59+
if config.use_tunix_gradient_accumulation:
60+
batch_size = global_batch_size // jax.process_count() // config.gradient_accumulation_steps
61+
else:
62+
batch_size = global_batch_size // jax.process_count()
63+
5864
if config.enable_data_shuffling:
5965
dataset = dataset.shuffle(seed=config.data_shuffle_seed)
6066

@@ -195,13 +201,21 @@ def preprocessing_pipeline(
195201
generate_padding_batch=False,
196202
use_dpo=None,
197203
use_sft=None,
204+
use_tunix_gradient_accumulation=False,
205+
num_microbatches=1,
198206
sft_train_on_completion_only=True,
199207
grain_worker_count=1, # only support 0 or 1
200208
max_segments_per_seq=None,
201209
):
202210
"""pipeline for preprocessing HF dataset"""
203211

204212
assert global_batch_size % global_mesh.size == 0, "Batch size should be divisible by number of global devices."
213+
# Tunix GA requires per-micro-batch slicing at the data level,
214+
# whereas Native GA processes the full batch and splits it internally.
215+
if use_tunix_gradient_accumulation:
216+
batch_size = global_batch_size // jax.process_count() // num_microbatches
217+
else:
218+
batch_size = global_batch_size // jax.process_count()
205219

206220
if shuffle:
207221
dataset = dataset.shuffle(seed=data_shuffle_seed)
@@ -303,15 +317,15 @@ def lists2array(x):
303317
max_segments = None
304318
operations.append(
305319
grain.experimental.PackAndBatchOperation(
306-
batch_size=global_batch_size // jax.process_count(),
320+
batch_size=batch_size,
307321
length_struct=length_struct,
308322
max_sequences_per_bin=max_segments,
309323
)
310324
)
311325
operations.append(_input_pipeline_utils.ReformatPacking(data_column_names))
312326
else:
313327
operations.append(_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
314-
operations.append(grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder))
328+
operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder))
315329

316330
if shift and not use_dpo:
317331
operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1))
@@ -390,6 +404,8 @@ def make_hf_train_iterator(
390404
generate_padding_batch=config.generate_padding_batch_train,
391405
use_dpo=config.use_dpo,
392406
use_sft=config.use_sft,
407+
use_tunix_gradient_accumulation=config.use_tunix_gradient_accumulation,
408+
num_microbatches=config.gradient_accumulation_steps,
393409
sft_train_on_completion_only=config.sft_train_on_completion_only,
394410
chat_template_path=config.chat_template_path,
395411
max_segments_per_seq=config.max_segments_per_seq,
@@ -443,6 +459,7 @@ def make_hf_eval_iterator(
443459
generate_padding_batch=config.generate_padding_batch_eval,
444460
use_dpo=config.use_dpo,
445461
use_sft=config.use_sft,
462+
num_microbatches=config.gradient_accumulation_steps,
446463
sft_train_on_completion_only=config.sft_train_on_completion_only,
447464
chat_template_path=config.chat_template_path,
448465
max_segments_per_seq=config.max_segments_per_seq,

src/MaxText/sft_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def main(argv: Sequence[str]) -> None:
165165
os.environ["LIBTPU_INIT_ARGS"] = (
166166
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
167167
)
168-
config = pyconfig.initialize(argv)
168+
config = pyconfig.initialize(argv, use_tunix_gradient_accumulation=False)
169169
jax.config.update("jax_use_shardy_partitioner", config.shardy)
170170
max_utils.print_system_information()
171171
train_utils.validate_train_config(config)

src/MaxText/train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,13 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
179179
# Zero1+GA to reduce communication overhead.
180180
# EPS was used to avoid division by zero, but it's not needed when gradient
181181
# accumulation is enabled since there's no division.
182-
if config.gradient_accumulation_steps > 1:
182+
if config.gradient_accumulation_steps > 1 and not config.use_tunix_gradient_accumulation:
183183
loss = total_loss
184184
else:
185+
# When using Tunix gradient accumulation, we revert to standard normalization.
186+
# Unlike the manual accumulation path above, Tunix (via optax.MultiSteps) expects
187+
# a normalized loss for each step. It handles the accumulation state
188+
# updates and scaling internally.
185189
loss = total_loss / (total_weights + EPS)
186190

187191
# Calculate and Add MTP Loss

0 commit comments

Comments (0)