
Commit c231424

add new pipeline weight prefetching config
1 parent 92e1117 commit c231424

8 files changed

Lines changed: 212 additions & 69 deletions


src/maxtext/configs/base.yml

Lines changed: 4 additions & 1 deletion
@@ -275,6 +275,9 @@ pipeline_parallel_layers: -1 # Pipeline only this number of layers - for the rem
 # PP degree divides the number of layers.
 # By default (when set to -1) we pipeline all of the decoder layers.

+# Pipeline weight prefetching is an advanced optimization for SPMD pipeline parallelism.
+# When enabled, the weight all-gathers needed by upcoming microbatches are prefetched ahead of the computation, reducing exposed collective time.
+use_pipeline_weight_prefetching: False

 # num_pipeline_microbatches must be a multiple of the number of pipeline stages. By default it is set to the number of stages.
 # Note the microbatch_size is given by global_batch_size / num_pipeline_microbatches, where global_batch_size = per_device_batch_size * num_devices
@@ -923,7 +926,7 @@ xprof_e2e_enable_fw_power_level_event: False
 xprof_e2e_enable_fw_thermal_event: False
 profile_power_events: False # Set to True to enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.

-log_config: False # Prints the config (after defaults have been set by pyconfig logic)
+log_config: True # Prints the config (after defaults have been set by pyconfig logic)
 debug_sharding: False # Prints model weights sharding info

 # Checkpoint Structured logging
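The two comment lines added above are the only description of the feature rendered in this commit, so here is a rough, hedged illustration of the pattern they describe. This is not MaxText's implementation and every name in it is invented: the scan loop carries the next iteration's weights in its state, so under jit the fetch of the upcoming weights can overlap with the current microbatch's compute.

# Minimal sketch (hypothetical names, not MaxText code) of weight prefetching
# in a pipelined scan loop: `all_weights` stacks one weight matrix per
# circular repeat, and the "prefetch" is selecting the next iteration's
# weights before computing, so XLA can overlap the gather with the matmul.
import jax
import jax.numpy as jnp


def pipeline_step(carry, _):
  x, prefetched_w, all_weights, i = carry
  # Start fetching the weights for iteration i + 1 before computing with the
  # weights fetched on the previous iteration.
  next_w = all_weights[(i + 1) % all_weights.shape[0]]
  x = jnp.tanh(x @ prefetched_w)  # compute for iteration i
  return (x, next_w, all_weights, i + 1), None


@jax.jit
def run_pipeline(x, all_weights):
  init = (x, all_weights[0], all_weights, jnp.array(0))
  (x, _, _, _), _ = jax.lax.scan(
      pipeline_step, init, None, length=all_weights.shape[0]
  )
  return x


print(run_pipeline(jnp.ones((4, 8)), jnp.ones((3, 8, 8))).shape)  # (4, 8)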

src/maxtext/configs/types.py

Lines changed: 14 additions & 0 deletions
@@ -840,6 +840,9 @@ class IciParallelism(BaseModel):
 class PipelineParallelism(BaseModel):
   """Configuration for pipeline parallelism."""

+  use_pipeline_weight_prefetching: bool = Field(
+      False, description="Enable weight prefetching for circular pipeline parallelism."
+  )
   num_layers_per_pipeline_stage: int = Field(1, description="Number of layers to place on each pipeline stage.")
   num_pipeline_repeats: int = Field(
       -1,
@@ -2237,6 +2240,17 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     )
     self.num_pipeline_repeats = num_pipeline_repeats

+    if self.use_pipeline_weight_prefetching:
+      assert self.num_pipeline_repeats > 1, "Pipeline weight prefetching only supports circular pipelines."
+      assert (
+          self.num_layers_per_pipeline_stage == 1
+      ), "Pipeline weight prefetching currently only supports one layer per pipeline stage."
+      assert (
+          not self.pipeline_delay_activation_forwarding
+      ), "Pipeline weight prefetching does not support delayed activation forwarding."
+      assert not self.quantization, "Quantization is currently not supported with pipeline weight prefetching."
+      assert not self.scan_layers_per_stage, "Pipeline weight prefetching currently does not support scanned layers."
+
     assert (num_stages * self.num_pipeline_repeats * self.num_layers_per_pipeline_stage) == (
         self.pipeline_parallel_layers
     ), (
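For reference, here is an illustrative set of flag values that passes every check in the block above. It is a sketch, not taken from a real config file.

# Illustrative flag values satisfying the new validation block: a circular
# schedule (num_pipeline_repeats > 1), one layer per stage, no delayed
# activation forwarding, no quantization, no per-stage layer scanning.
valid_prefetching_settings = {
    "use_pipeline_weight_prefetching": True,
    "num_pipeline_repeats": 4,  # > 1, i.e. a circular pipeline
    "num_layers_per_pipeline_stage": 1,
    "pipeline_delay_activation_forwarding": False,
    "quantization": "",  # falsy, so the quantization assert passes
    "scan_layers_per_stage": False,
}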

src/maxtext/layers/decoders.py

Lines changed: 1 addition & 8 deletions
@@ -796,7 +796,7 @@ def __call__(
     if cfg.using_pipeline_parallelism:
       logical_partition_spec = (
           self.pipeline_module.get_weight_sharding(y, decoder_segment_ids, decoder_positions, deterministic, model_mode)
-          if cfg.quantization == ""
+          if cfg.pipeline_fsdp_ag_once or cfg.use_pipeline_weight_prefetching
           else None
       )
     if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
@@ -1086,13 +1086,6 @@ def __call__(

     else:
       logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
-      logits = sharding.maybe_shard_with_logical(
-          logits,
-          ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
-          mesh=self.mesh,
-          shard_mode=self.config.shard_mode,
-          debug_sharding=self.config.debug_sharding,
-      )

     # The API of the Decoder is now a tuple, providing both the main output
     # and the raw hidden state needed for auxiliary tasks.

src/maxtext/layers/pipeline.py

Lines changed: 139 additions & 54 deletions
Large diffs are not rendered by default.

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 1 addition & 1 deletion
@@ -809,7 +809,7 @@ def gmm(
         group_sizes,
         representative_value=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
     )
-    if config.use_qwix_quantization or config.using_pipeline_parallelism:
+    if config.use_qwix_quantization or (config.using_pipeline_parallelism and config.use_pipeline_weight_prefetching):
       output = megablox.gmm(
           lhs=inputs,
           rhs=kernel,

src/maxtext/utils/pipeline_utils.py

Lines changed: 1 addition & 1 deletion
@@ -255,7 +255,7 @@ def create_run_scannable(
   """Creates a scannable function for pipeline loop iterations."""

   def run_scannable(model, loop_state):
-    loop_state["bsw"] = model.bsw_all_gather_over_fsdp(
+    loop_state["bsw"] = model.weight_prefetching(
         loop_state["weights"], physical_partition_spec, loop_state["loop_iteration"]
     )
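This call site is the only rendered piece of the new prefetching path (the pipeline.py diff above is collapsed), so the following is only a guess at the shape of the pattern, inferred from this one line: the scan body refreshes the gathered weights ("bsw") in the carried loop state before the stage computation consumes them. All helper names below are stand-ins; in MaxText, model.weight_prefetching also handles the FSDP all-gather and the physical_partition_spec.

# Hedged, self-contained sketch of the loop pattern at the call site above.
import jax
import jax.numpy as jnp


def weight_prefetching(stacked_weights, loop_iteration):
  # Stand-in for model.weight_prefetching(weights, spec, loop_iteration):
  # pick the weights this iteration needs from a stacked buffer.
  return stacked_weights[loop_iteration % stacked_weights.shape[0]]


def run_scannable(loop_state, _):
  # Refresh the gathered ("bsw") weights first, as in the diff above.
  loop_state["bsw"] = weight_prefetching(
      loop_state["weights"], loop_state["loop_iteration"]
  )
  loop_state["activations"] = jnp.tanh(
      loop_state["activations"] @ loop_state["bsw"]
  )
  loop_state["loop_iteration"] = loop_state["loop_iteration"] + 1
  return loop_state, None


weights = jnp.ones((4, 8, 8))
state = {
    "weights": weights,
    "bsw": weights[0],
    "activations": jnp.ones((2, 8)),
    "loop_iteration": jnp.array(0),
}
state, _ = jax.lax.scan(run_scannable, state, None, length=4)
print(state["activations"].shape)  # (2, 8)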

tests/unit/pipeline_parallelism_test.py

Lines changed: 52 additions & 0 deletions
@@ -278,6 +278,24 @@ def test_circular_ag_once(self):
     )
     self.assert_pipeline_same_output_and_grad(config)

+  @pytest.mark.tpu_only
+  def test_circular_pipeline_prefetching(self):
+    # 2 stages, 8 microbatches, pipeline weight prefetching enabled
+    config = pyconfig.initialize(
+        [sys.argv[0], get_test_config_path()],
+        enable_checkpointing=False,
+        enable_goodput_recording=False,
+        run_name="circular_prefetching",
+        max_target_length=128,
+        base_emb_dim=28,
+        ici_pipeline_parallelism=2,
+        base_num_decoder_layers=8,
+        num_pipeline_microbatches=8,
+        per_device_batch_size=4,
+        use_pipeline_weight_prefetching=True,
+    )
+    self.assert_pipeline_same_output_and_grad(config)
+
   @pytest.mark.tpu_only
   def test_non_circular_same_output_and_grad(self):
     # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches
@@ -326,6 +344,40 @@ def test_full_train_circular(self):
         ]
     )

+  @pytest.mark.integration_test
+  @pytest.mark.tpu_only
+  def test_full_train_circular_pipeline_prefetching(self):
+    # Run a full train.py call with 2 stages, 32 layers (1 layer per stage, 16 circular repeats),
+    # 4 microbatches, and pipeline weight prefetching enabled
+    train_main(
+        [
+            None,
+            get_test_config_path(),
+            f"base_output_directory={self.base_output_directory}",
+            "run_name=runner_pipeline_parallelism_test",
+            f"dataset_path={self.dataset_path}",
+            "base_emb_dim=28",
+            "base_num_query_heads=4",
+            "base_num_kv_heads=4",
+            "base_mlp_dim=32",
+            "base_num_decoder_layers=32",
+            "head_dim=128",
+            "per_device_batch_size=2",
+            "max_target_length=1024",
+            "vocab_size=32",
+            "dataset_type=synthetic",
+            "steps=3",
+            "enable_checkpointing=False",
+            "enable_goodput_recording=False",
+            "ici_pipeline_parallelism=2",
+            "num_layers_per_pipeline_stage=1",
+            "num_pipeline_microbatches=4",
+            "use_pipeline_weight_prefetching=True",
+            rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
+            "scan_layers_per_stage=False",  # We see better performance only scanning the pipeline iterations.
+        ]
+    )
+
   @pytest.mark.tpu_only
   def test_delay_activation_forwarding_same_output_and_grad(self):
     # 4 stages, delayed activation forwarding, 8 layers (2 repeats, 1 layer per stage), 8 microbatches

tests/unit/train_compile_test.py

Lines changed: 0 additions & 4 deletions
@@ -860,7 +860,6 @@ def test_circular_pipeline_compile_pp_fsdp_fsdpt_ep_ds3(self):
         "use_random_routing=true",
         "allow_split_physical_axes=true",
         "max_target_length=4096",
-        "remat_policy=custom",
       )
     )

@@ -881,10 +880,8 @@ def test_circular_pipeline_compile_pp_fsdp_tp_ds3(self):
         "pipeline_parallel_layers=56",
         "num_pipeline_microbatches=16",
         "model_name=deepseek3-671b",
-        "ici_expert_parallelism=4",
         "allow_split_physical_axes=true",
         "max_target_length=4096",
-        "remat_policy=custom",
       )
     )

@@ -910,7 +907,6 @@
         "use_random_routing=false",
         "allow_split_physical_axes=true",
         "max_target_length=4096",
-        "remat_policy=custom",
       )
     )