
Commit fbcee8f

Merge pull request #2744 from AI-Hypercomputer:chengnuojin-explicit-pipeline

PiperOrigin-RevId: 850269548

2 parents 2887b75 + 112f8c3

12 files changed: 370 additions & 324 deletions

src/MaxText/configs/base.yml (0 additions, 3 deletions)

```diff
@@ -262,9 +262,6 @@ pipeline_delay_activation_forwarding: False # This delays the activation forward
 # the communication and compute in each iteration are now independent. However this comes at the cost of doubling the pipeline bubble,
 # and you must set the number of microbatches to at least 2 * num_stages (the minimum 2 * num_stages is set by default with this delay).
 
-model_fsdp_ag_once: False # This controls whether the Zero-1 optimization is active.
-# This is a memory/time tradeoff - True: This is Zero-1 Sharding. Use ZeroOneTransformer to gather weights once per gradient step.
-# False: This is Zero-3 Sharing. Use the standard Transformer, which gathers for each microbatch's fwd/bwd pass.
 pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights over FSDP before the first pipeline iteration.
 # This is a memory/time tradeoff - we now have to store the FSDP gathered weights and gradients (typically in bf16), as opposed
 # to only one stage's worth, however we only execute one all-gather and reduce across per repeat, as opposed
```
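A quick back-of-envelope shows the scale of the memory/time tradeoff the surviving `pipeline_fsdp_ag_once` comment describes. The parameter count and pipeline depth below are made-up illustration values, not MaxText defaults:

```python
# Back-of-envelope for the pipeline_fsdp_ag_once tradeoff; all numbers here
# are illustrative assumptions. Gradients held in the same dtype roughly
# double each figure, as the config comment notes.
num_params = 8e9      # hypothetical 8B-parameter model
bytes_per_bf16 = 2
num_stages = 4        # hypothetical pipeline depth

all_stages_gb = num_params * bytes_per_bf16 / 1e9  # gathered once, all weights resident
one_stage_gb = all_stages_gb / num_stages          # default: only the active stage's slice

print(f"ag_once weights resident:   {all_stages_gb:.0f} GB")   # 16 GB
print(f"per-stage weights resident: {one_stage_gb:.0f} GB")    # 4 GB
```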

src/MaxText/configs/types.py (0 additions, 6 deletions)

```diff
@@ -1494,10 +1494,6 @@ class DerivedValues(BaseModel):
       None,
       description="Boolean flag indicating if pipeline parallelism is active across ICI or DCN.",
   )
-  model_fsdp_ag_once: bool = Field(
-      False,
-      description="An alias for `pipeline_fsdp_ag_once` for backward compatibility.",
-  )
 
   context_parallel_size: None | int = Field(
       None,
@@ -1990,8 +1986,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     ):
       self.logical_axis_rules.append(["aqt_amax_history", ("stage",)])
 
-    self.model_fsdp_ag_once = self.pipeline_fsdp_ag_once  # Backward compatibility alias
-
     # H. RUN ALL CROSS-FIELD VALIDATIONS
     if self.load_parameters_path and self.load_full_state_path:
       raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.")
```
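For context, the two deletions above together removed a read-through alias in the pydantic config model: a `Field` declaring the old name, plus a post-derivation assignment copying the canonical flag into it. A minimal standalone sketch of that pattern (the wrapper class is hypothetical; only the two field names mirror MaxText):

```python
# Minimal sketch of the backward-compatibility alias this commit removes.
from pydantic import BaseModel, model_validator

class PipelineFlags(BaseModel):
  pipeline_fsdp_ag_once: bool = False
  model_fsdp_ag_once: bool = False  # alias, kept only so old configs keep working

  @model_validator(mode="after")
  def _mirror_alias(self):
    # Equivalent to the removed assignment in types.py: copy the canonical
    # flag into the alias after all other derived values are computed.
    self.model_fsdp_ag_once = self.pipeline_fsdp_ag_once
    return self

# Reads through the alias still worked while the copy was in place.
assert PipelineFlags(pipeline_fsdp_ag_once=True).model_fsdp_ag_once
```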

src/MaxText/layers/decoders.py (4 additions, 4 deletions)

```diff
@@ -717,11 +717,11 @@ def __call__(
       )
     if cfg.using_pipeline_parallelism:
       if cfg.pipeline_fsdp_ag_once:
-        partition_spec = self.pipeline_module.get_weight_sharding(
+        logical_partition_spec = self.pipeline_module.get_weight_sharding(
             y, decoder_segment_ids, decoder_positions, deterministic, model_mode
         )
       else:
-        partition_spec = None  # This partition spec is only used for the fsdp_ag_once feature.
+        logical_partition_spec = None  # This partition spec is only used for the fsdp_ag_once feature.
       if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
         assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
         dense_layer = RemattedBlockLayers[0]
@@ -750,9 +750,9 @@ def __call__(
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
             model_mode=model_mode,
         )(y, *broadcast_args)
-        y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec)
+        y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
       else:  # Not DeepSeek
-        y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec)
+        y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
       remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
       if remaining_layers > 0:
         logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
```
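The rename appears to be more than cosmetic: the new name signals that `get_weight_sharding` hands back axis names in the *logical* namespace, which only become physical mesh axes once `logical_axis_rules` is applied. A minimal Flax sketch of that translation, using made-up rules and axis names rather than MaxText's real rule set:

```python
# Logical vs. mesh partition specs, sketched with flax.linen; the rules and
# axis names below are illustrative assumptions.
import flax.linen as nn

rules = (("embed", "fsdp"), ("mlp", "tensor"))  # hypothetical logical->mesh rules
logical_partition_spec = ("embed", "mlp")       # what a weight-sharding query returns

mesh_spec = nn.logical_to_mesh_axes(logical_partition_spec, rules)
print(mesh_spec)  # PartitionSpec('fsdp', 'tensor') -- the physical layout
```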

src/MaxText/layers/models.py (0 additions, 118 deletions)

```diff
@@ -35,7 +35,6 @@
 from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen
 from MaxText.layers.quantizations import AqtQuantization as Quant
 from MaxText.layers.multi_token_prediction import multi_token_prediction_block_as_linen
-from MaxText.maxtext_utils import all_gather_over_fsdp
 
 # ------------------------------------------------------------------------------
 # The network: Transformer Definitions
@@ -517,120 +516,3 @@ def __call__(
       return hidden_state, kv_caches
 
     return logits
-
-
-class ZeroOneTransformer(nn.Module):
-  """
-  A wrapper for the base Transformer model designed to implement the Zero-1
-  FSDP optimization.
-
-  The goal of this optimization is to reduce communication overhead. In the standard
-  FSDP implementation, an all-gather operation on the model weights is performed twice
-  for each gradient accumulation microbatch (once for the forward pass, once for the backward pass).
-  This class changes that behavior. When enabled, it performs the all-gather operation
-  only *once* per full gradient accumulation step. It gathers the full weights into
-  memory, runs all the microbatch forward and backward passes, and then releases the
-  full weights. This trades higher peak memory usage for significantly reduced
-  network communication, which can improve training speed if sufficient memory is
-  available.
-  """
-
-  config: Config
-  mesh: Mesh
-  quant: Quant
-  # Possible model_mode values can be found in MaxText.common_types.
-  # We generally use MaxText.common_types.MODEL_MODE_TRAIN or
-  # MaxText.common_types.MODEL_MODE_PREFILL for initializations here.
-  # TODO: Make model_mode required after confirming no users are affected.
-  model_mode: str = MODEL_MODE_TRAIN  # May be different than the model_mode passed to __call__
-
-  def setup(self):
-    """Sets up the underlying Transformer model.
-
-    This method initializes the `self.model` attribute by calling the
-    `transformer_as_linen` factory function.
-    """
-    self.model = transformer_as_linen(self.config, self.mesh, self.quant, self.model_mode)
-
-  def __call__(
-      self,
-      decoder_input_tokens: jnp.ndarray,
-      decoder_positions: jnp.ndarray,
-      decoder_segment_ids=None,
-      encoder_images: None | jnp.ndarray = None,
-      encoder_image_masks: None | jnp.ndarray = None,
-      enable_dropout=True,
-      model_mode=MODEL_MODE_TRAIN,
-      previous_chunk=None,
-      true_length: None | int = None,
-      slot: None | int = None,
-      page_state: None | page_manager.PageState = None,
-      partition_spec=None,
-      decoder_target_tokens: None | jnp.ndarray = None,
-      decoder_target_mask: None | jnp.ndarray = None,
-      nnx_method: str | None = None,
-  ):
-    """Applies the Zero-1 FSDP wrapped Transformer model.
-
-    This method handles the all-gather operation for model weights before
-    applying the underlying Transformer model, and then releases them.
-
-    Args:
-      decoder_input_tokens: Input tokens for the decoder.
-      decoder_positions: Positional encodings for the decoder inputs.
-      decoder_segment_ids: Segment IDs for the decoder inputs (optional).
-      encoder_images: Encoder images for multimodal models (optional).
-      enable_dropout: Whether to enable dropout. Defaults to True.
-      previous_chunk: Previous chunk for incremental decoding (optional).
-      true_length: True length of the prompt before padding (optional).
-      slot: An integer representing the decode batch index selected for this
-        request (optional).
-      page_state: Page state for paged attention (optional).
-      partition_spec: Partition specification for FSDP all-gather.
-      decoder_target_tokens: Target tokens for the decoder (optional, used in
-        MTP).
-      decoder_target_mask: Target mask for the decoder (optional, used in MTP).
-      nnx_method: Method to call on the NNX module (optional).
-
-    Returns:
-      Logits from the Transformer model.
-    """
-    if self.is_initializing():
-      return self.model(
-          decoder_input_tokens=decoder_input_tokens,
-          decoder_positions=decoder_positions,
-          decoder_segment_ids=decoder_segment_ids,
-          encoder_images=encoder_images,
-          encoder_image_masks=encoder_image_masks,
-          enable_dropout=enable_dropout,
-          model_mode=model_mode,
-          previous_chunk=previous_chunk,
-          true_length=true_length,
-          slot=slot,
-          page_state=page_state,
-      )
-    all_model_weights = all_gather_over_fsdp(
-        self.model.variables,
-        partition_spec,
-        mesh=self.mesh,
-        logical_axis_rules=self.config.logical_axis_rules,
-    )
-
-    return self.model.apply(
-        all_model_weights,
-        decoder_input_tokens=decoder_input_tokens,
-        decoder_positions=decoder_positions,
-        decoder_segment_ids=decoder_segment_ids,
-        encoder_images=encoder_images,
-        encoder_image_masks=encoder_image_masks,
-        enable_dropout=enable_dropout,
-        model_mode=model_mode,
-        previous_chunk=previous_chunk,
-        true_length=true_length,
-        slot=slot,
-        page_state=page_state,
-        mutable=False,
-        decoder_target_tokens=decoder_target_tokens,
-        decoder_target_mask=decoder_target_mask,
-        nnx_method=nnx_method,
-    )
```
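The deleted docstring describes a gather-once schedule: one weight all-gather per gradient step, reused across every microbatch's forward and backward pass. As a standalone illustration, the sketch below shows that schedule in plain JAX; the mesh layout, loss, and shapes are made up, and `with_sharding_constraint` stands in for MaxText's `all_gather_over_fsdp` helper:

```python
# Gather-once (Zero-1 style) gradient accumulation, sketched in plain JAX.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("fsdp",))
sharded = NamedSharding(mesh, P("fsdp"))  # Zero-3 resident layout: weights split over fsdp
gathered = NamedSharding(mesh, P())       # fully replicated layout after the all-gather

def loss_fn(w, x):
  return jnp.mean((x @ w) ** 2)  # stand-in loss

@jax.jit
def grad_step(w, microbatches):
  # One all-gather per gradient step: constrain the weights to the replicated
  # layout once, then reuse that copy for every microbatch's fwd/bwd pass,
  # instead of re-gathering inside each microbatch.
  w_full = jax.lax.with_sharding_constraint(w, gathered)
  per_mb_grads = jax.vmap(lambda x: jax.grad(loss_fn)(w_full, x))(microbatches)
  return jnp.sum(per_mb_grads, axis=0)  # accumulate over microbatches

dim = 8 * jax.device_count()
w = jax.device_put(jnp.ones((dim, 4)), sharded)
microbatches = jnp.ones((3, 2, dim))  # 3 microbatches, batch 2 each
print(grad_step(w, microbatches).shape)  # (dim, 4)
```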
