
Commit 92e1117

refactor pipeline code
1 parent 76b193b · commit 92e1117

7 files changed

Lines changed: 1837 additions & 512 deletions


src/maxtext/configs/types.py

Lines changed: 0 additions & 33 deletions
@@ -2526,39 +2526,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       raise ValueError("`share_kv_projections` is not compatible with `attention_type='mla'`.")
 
     # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
-    # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
-    # if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
-    #   self.ici_parallelism = [
-    #       self.ici_diloco_parallelism,
-    #       self.ici_pipeline_parallelism,
-    #       self.ici_data_parallelism,
-    #       self.ici_fsdp_parallelism,
-    #       self.ici_fsdp_transpose_parallelism,
-    #       self.ici_sequence_parallelism,
-    #       self.ici_context_parallelism,
-    #       self.ici_context_autoregressive_parallelism,
-    #       self.ici_tensor_parallelism,
-    #       self.ici_tensor_transpose_parallelism,
-    #       self.ici_tensor_sequence_parallelism,
-    #       self.ici_expert_parallelism,
-    #       self.ici_autoregressive_parallelism,
-    #   ]
-    #   self.dcn_parallelism = [
-    #       self.dcn_diloco_parallelism,
-    #       self.dcn_pipeline_parallelism,
-    #       self.dcn_data_parallelism,
-    #       self.dcn_fsdp_parallelism,
-    #       self.dcn_fsdp_transpose_parallelism,
-    #       self.dcn_sequence_parallelism,
-    #       self.dcn_context_parallelism,
-    #       self.dcn_context_autoregressive_parallelism,
-    #       self.dcn_tensor_parallelism,
-    #       self.dcn_tensor_transpose_parallelism,
-    #       self.dcn_tensor_sequence_parallelism,
-    #       self.dcn_expert_parallelism,
-    #       self.dcn_autoregressive_parallelism,
-    #   ]
-    # else:
     ici_map = {
         "diloco": self.ici_diloco_parallelism,
         "data": self.ici_data_parallelism,

src/maxtext/layers/attention_op.py

Lines changed: 1 addition & 3 deletions
@@ -580,9 +580,7 @@ def maybe_create_nnx(einsum, *args):
 
   def _logical_to_mesh_axes(self, logical_name):
     logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules
-    return logical_to_mesh_axes(
-        logical_name, mesh=self.mesh, rules=logical_rules
-    )
+    return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=logical_rules)
 
   def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None:
     """Check attention inputs."""

src/maxtext/layers/decoders.py

Lines changed: 5 additions & 3 deletions
@@ -307,7 +307,7 @@ def setup(self):
     if self.config.using_pipeline_parallelism:
       pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
       remat_policy = self.get_remat_policy()
-      self.pipeline_module = pipeline.Pipeline(
+      self.pipeline_module = pipeline.create_pipeline(
           config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
       )
 
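Note: the call site now goes through `pipeline.create_pipeline` instead of constructing `pipeline.Pipeline` directly, with the same keyword arguments. A minimal sketch of what such a factory could look like, using a stand-in `Pipeline` class since the real module isn't shown in this diff.

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class Pipeline:
  """Stand-in for the real pipeline module class (illustrative only)."""
  config: Any
  mesh: Any
  layers: Any
  remat_policy: Any

def create_pipeline(*, config, mesh, layers, remat_policy):
  """Factory that centralizes pipeline construction for all callers."""
  # Any config-dependent choice (e.g. which pipeline schedule to build) can
  # live here rather than in every call site such as Decoder.setup.
  return Pipeline(config=config, mesh=mesh, layers=layers, remat_policy=remat_policy)

# Usage mirroring the call site in the diff:
pipeline_module = create_pipeline(config={}, mesh=None, layers=[], remat_policy=None)
print(type(pipeline_module).__name__)  # Pipeline
```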

@@ -794,8 +794,10 @@ def __call__(
         model_mode,
     )
     if cfg.using_pipeline_parallelism:
-      logical_partition_spec = self.pipeline_module.get_weight_sharding(
-          y, decoder_segment_ids, decoder_positions, deterministic, model_mode
+      logical_partition_spec = (
+          self.pipeline_module.get_weight_sharding(y, decoder_segment_ids, decoder_positions, deterministic, model_mode)
+          if cfg.quantization == ""
+          else None
       )
     if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
       assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
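Note: this hunk makes the weight-sharding query conditional: `get_weight_sharding` only runs when quantization is disabled (`cfg.quantization == ""`); otherwise `logical_partition_spec` is left as `None`. A small self-contained sketch of the same guard pattern, with placeholder types standing in for the real config and pipeline module.

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Cfg:
  quantization: str = ""  # "" mirrors the diff's "quantization disabled" check

class FakePipeline:
  def get_weight_sharding(self, *args):
    # Placeholder for the real sharding query on the pipeline module.
    return ("logical", "partition", "spec")

def weight_sharding_or_none(cfg: Cfg, pipeline_module: FakePipeline, *args) -> Optional[Any]:
  # Only query weight sharding when quantization is off; the quantized path
  # is assumed here to proceed without a partition spec.
  return pipeline_module.get_weight_sharding(*args) if cfg.quantization == "" else None

print(weight_sharding_or_none(Cfg(""), FakePipeline()))      # ('logical', 'partition', 'spec')
print(weight_sharding_or_none(Cfg("int8"), FakePipeline()))  # None
```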
