Skip to content

Commit 1ce4481

Browse files
gagikaNuojCheng
authored and committed
Enable grain input pipeline save and restore for distillation.
Simple fix on debug sharding log; add all-gather insertion per repeat; working all-gather insertion (clean version); fsdp+pp bug free; add bsw checkpoint; split bsw all-gather into two; add custom vjp.
1 parent 2f8c473 commit 1ce4481

5 files changed

Lines changed: 425 additions & 158 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights
299299
# It may be useful to do the reverse when the layers_per_stage is very large.
300300
# The below settings only have effect when using pipeline parallelism.
301301
scan_pipeline_iterations: True
302+
scan_pipeline_repeats: True
302303
scan_layers_per_stage: False
303304
set_remat_policy_on_pipeline_iterations: True
304305
set_remat_policy_on_layers_per_stage: False
@@ -922,7 +923,7 @@ xprof_e2e_enable_fw_power_level_event: False
922923
xprof_e2e_enable_fw_thermal_event: False
923924
profile_power_events: False # Set to True to enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.
924925

925-
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
926+
log_config: False # Prints the config (after defaults have been set by pyconfig logic)
926927
debug_sharding: False # Prints model weights sharding info
927928

928929
# Checkpoint Structured logging

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,7 @@ class PipelineParallelism(BaseModel):
855855
)
856856
pipeline_fsdp_ag_once: bool = Field(False, description="If True, all-gather FSDP weights once per pipeline repeat.")
857857
scan_pipeline_iterations: bool = Field(True, description="Use jax.lax.scan over pipeline iterations.")
858+
scan_pipeline_repeats: bool = Field(True, description="Use jax.lax.scan over pipeline repeats.")
858859
scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.")
859860
set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.")
860861
set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.")

src/maxtext/layers/decoders.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -794,12 +794,9 @@ def __call__(
794794
model_mode,
795795
)
796796
if cfg.using_pipeline_parallelism:
797-
if cfg.pipeline_fsdp_ag_once:
798-
logical_partition_spec = self.pipeline_module.get_weight_sharding(
799-
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
800-
)
801-
else:
802-
logical_partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
797+
logical_partition_spec = self.pipeline_module.get_weight_sharding(
798+
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
799+
)
803800
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
804801
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
805802
dense_layer = RemattedBlockLayers[0]
@@ -1087,6 +1084,13 @@ def __call__(
10871084

10881085
else:
10891086
logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
1087+
logits = sharding.maybe_shard_with_logical(
1088+
logits,
1089+
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
1090+
mesh=self.mesh,
1091+
shard_mode=self.config.shard_mode,
1092+
debug_sharding=self.config.debug_sharding,
1093+
)
10901094

10911095
# The API of the Decoder is now a tuple, providing both the main output
10921096
# and the raw hidden state needed for auxiliary tasks.

0 commit comments

Comments (0)