|
35 | 35 | from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen |
36 | 36 | from MaxText.layers.quantizations import AqtQuantization as Quant |
37 | 37 | from MaxText.layers.multi_token_prediction import multi_token_prediction_block_as_linen |
38 | | -from MaxText.maxtext_utils import all_gather_over_fsdp |
39 | 38 |
|
40 | 39 | # ------------------------------------------------------------------------------ |
41 | 40 | # The network: Transformer Definitions |
@@ -517,120 +516,3 @@ def __call__( |
517 | 516 | return hidden_state, kv_caches |
518 | 517 |
|
519 | 518 | return logits |
520 | | - |
521 | | - |
522 | | -class ZeroOneTransformer(nn.Module): |
523 | | - """ |
524 | | - A wrapper for the base Transformer model designed to implement the Zero-1 |
525 | | - FSDP optimization. |
526 | | -
|
527 | | - The goal of this optimization is to reduce communication overhead. In the standard |
528 | | - FSDP implementation, an all-gather operation on the model weights is performed twice |
529 | | - for each gradient accumulation microbatch (once for the forward pass, once for the backward pass). |
530 | | - This class changes that behavior. When enabled, it performs the all-gather operation |
531 | | - only *once* per full gradient accumulation step. It gathers the full weights into |
532 | | - memory, runs all the microbatch forward and backward passes, and then releases the |
533 | | - full weights. This trades higher peak memory usage for significantly reduced |
534 | | - network communication, which can improve training speed if sufficient memory is |
535 | | - available. |
536 | | - """ |
537 | | - |
538 | | - config: Config |
539 | | - mesh: Mesh |
540 | | - quant: Quant |
541 | | - # Possible model_mode values can be found in MaxText.common_types. |
542 | | - # We generally use MaxText.common_types.MODEL_MODE_TRAIN or |
543 | | - # MaxText.common_types.MODEL_MODE_PREFILL for initializations here. |
544 | | - # TODO: Make model_mode required after confirming no users are affected. |
545 | | - model_mode: str = MODEL_MODE_TRAIN # May be different than the model_mode passed to __call__ |
546 | | - |
547 | | - def setup(self): |
548 | | - """Sets up the underlying Transformer model. |
549 | | -
|
550 | | - This method initializes the `self.model` attribute by calling the |
551 | | - `transformer_as_linen` factory function. |
552 | | - """ |
553 | | - self.model = transformer_as_linen(self.config, self.mesh, self.quant, self.model_mode) |
554 | | - |
555 | | - def __call__( |
556 | | - self, |
557 | | - decoder_input_tokens: jnp.ndarray, |
558 | | - decoder_positions: jnp.ndarray, |
559 | | - decoder_segment_ids=None, |
560 | | - encoder_images: None | jnp.ndarray = None, |
561 | | - encoder_image_masks: None | jnp.ndarray = None, |
562 | | - enable_dropout=True, |
563 | | - model_mode=MODEL_MODE_TRAIN, |
564 | | - previous_chunk=None, |
565 | | - true_length: None | int = None, |
566 | | - slot: None | int = None, |
567 | | - page_state: None | page_manager.PageState = None, |
568 | | - partition_spec=None, |
569 | | - decoder_target_tokens: None | jnp.ndarray = None, |
570 | | - decoder_target_mask: None | jnp.ndarray = None, |
571 | | - nnx_method: str | None = None, |
572 | | - ): |
573 | | - """Applies the Zero-1 FSDP wrapped Transformer model. |
574 | | -
|
575 | | - This method handles the all-gather operation for model weights before |
576 | | - applying the underlying Transformer model, and then releases them. |
577 | | -
|
578 | | - Args: |
579 | | - decoder_input_tokens: Input tokens for the decoder. |
580 | | - decoder_positions: Positional encodings for the decoder inputs. |
581 | | - decoder_segment_ids: Segment IDs for the decoder inputs (optional). |
582 | | - encoder_images: Encoder images for multimodal models (optional). |
583 | | - enable_dropout: Whether to enable dropout. Defaults to True. |
584 | | - previous_chunk: Previous chunk for incremental decoding (optional). |
585 | | - true_length: True length of the prompt before padding (optional). |
586 | | - slot: An integer representing the decode batch index selected for this |
587 | | - request (optional). |
588 | | - page_state: Page state for paged attention (optional). |
589 | | - partition_spec: Partition specification for FSDP all-gather. |
590 | | - decoder_target_tokens: Target tokens for the decoder (optional, used in |
591 | | - MTP). |
592 | | - decoder_target_mask: Target mask for the decoder (optional, used in MTP). |
593 | | - nnx_method: Method to call on the NNX module (optional). |
594 | | -
|
595 | | - Returns: |
596 | | - Logits from the Transformer model. |
597 | | - """ |
598 | | - if self.is_initializing(): |
599 | | - return self.model( |
600 | | - decoder_input_tokens=decoder_input_tokens, |
601 | | - decoder_positions=decoder_positions, |
602 | | - decoder_segment_ids=decoder_segment_ids, |
603 | | - encoder_images=encoder_images, |
604 | | - encoder_image_masks=encoder_image_masks, |
605 | | - enable_dropout=enable_dropout, |
606 | | - model_mode=model_mode, |
607 | | - previous_chunk=previous_chunk, |
608 | | - true_length=true_length, |
609 | | - slot=slot, |
610 | | - page_state=page_state, |
611 | | - ) |
612 | | - all_model_weights = all_gather_over_fsdp( |
613 | | - self.model.variables, |
614 | | - partition_spec, |
615 | | - mesh=self.mesh, |
616 | | - logical_axis_rules=self.config.logical_axis_rules, |
617 | | - ) |
618 | | - |
619 | | - return self.model.apply( |
620 | | - all_model_weights, |
621 | | - decoder_input_tokens=decoder_input_tokens, |
622 | | - decoder_positions=decoder_positions, |
623 | | - decoder_segment_ids=decoder_segment_ids, |
624 | | - encoder_images=encoder_images, |
625 | | - encoder_image_masks=encoder_image_masks, |
626 | | - enable_dropout=enable_dropout, |
627 | | - model_mode=model_mode, |
628 | | - previous_chunk=previous_chunk, |
629 | | - true_length=true_length, |
630 | | - slot=slot, |
631 | | - page_state=page_state, |
632 | | - mutable=False, |
633 | | - decoder_target_tokens=decoder_target_tokens, |
634 | | - decoder_target_mask=decoder_target_mask, |
635 | | - nnx_method=nnx_method, |
636 | | - ) |
0 commit comments