Skip to content

Commit c6fcd0e

Browse files
committed
enable pipeline parallelism (pp) with batch-split DeepSeek (ds)
1 parent 1ce4481 commit c6fcd0e

8 files changed

Lines changed: 351 additions & 351 deletions

File tree

src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,21 @@ rope_truncate: True
5656
rope_attention_scaling: False
5757

5858
override_logical_axis_rules: True
59-
mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']
60-
data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
59+
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
60+
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
6161
logical_axis_rules: [
6262
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
63-
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
63+
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6464
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
6565
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
6666
['activation_norm_length', ['context']],
6767
['activation_heads', []],
68+
['activation_stage', 'stage'],
6869
['embed', ['fsdp']],
6970
['embed_no_exp', ['fsdp']],
7071
['q_lora', ['fsdp']],
7172
['kv_lora', ['fsdp']],
73+
['layers', 'stage'],
7274
['q_lora_up_proj', ['fsdp_transpose', 'expert']],
7375
['kv_lora_up_proj', ['fsdp_transpose', 'expert']],
7476
['q_heads', ['fsdp_transpose', 'expert']],

src/maxtext/configs/types.py

Lines changed: 69 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2527,77 +2527,77 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25272527

25282528
# I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
25292529
# Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
2530-
if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
2531-
self.ici_parallelism = [
2532-
self.ici_diloco_parallelism,
2533-
self.ici_pipeline_parallelism,
2534-
self.ici_data_parallelism,
2535-
self.ici_fsdp_parallelism,
2536-
self.ici_fsdp_transpose_parallelism,
2537-
self.ici_sequence_parallelism,
2538-
self.ici_context_parallelism,
2539-
self.ici_context_autoregressive_parallelism,
2540-
self.ici_tensor_parallelism,
2541-
self.ici_tensor_transpose_parallelism,
2542-
self.ici_tensor_sequence_parallelism,
2543-
self.ici_expert_parallelism,
2544-
self.ici_autoregressive_parallelism,
2545-
]
2546-
self.dcn_parallelism = [
2547-
self.dcn_diloco_parallelism,
2548-
self.dcn_pipeline_parallelism,
2549-
self.dcn_data_parallelism,
2550-
self.dcn_fsdp_parallelism,
2551-
self.dcn_fsdp_transpose_parallelism,
2552-
self.dcn_sequence_parallelism,
2553-
self.dcn_context_parallelism,
2554-
self.dcn_context_autoregressive_parallelism,
2555-
self.dcn_tensor_parallelism,
2556-
self.dcn_tensor_transpose_parallelism,
2557-
self.dcn_tensor_sequence_parallelism,
2558-
self.dcn_expert_parallelism,
2559-
self.dcn_autoregressive_parallelism,
2560-
]
2561-
else:
2562-
ici_map = {
2563-
"diloco": self.ici_diloco_parallelism,
2564-
"data": self.ici_data_parallelism,
2565-
"stage": self.ici_pipeline_parallelism,
2566-
"fsdp": self.ici_fsdp_parallelism,
2567-
"fsdp_transpose": self.ici_fsdp_transpose_parallelism,
2568-
"sequence": self.ici_sequence_parallelism,
2569-
"context": self.ici_context_parallelism,
2570-
"context_autoregressive": self.ici_context_autoregressive_parallelism,
2571-
"tensor": self.ici_tensor_parallelism,
2572-
"tensor_transpose": self.ici_tensor_transpose_parallelism,
2573-
"tensor_sequence": self.ici_tensor_sequence_parallelism,
2574-
"model": self.ici_tensor_parallelism,
2575-
"expert": self.ici_expert_parallelism,
2576-
"autoregressive": self.ici_autoregressive_parallelism,
2577-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2530+
# if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
2531+
# self.ici_parallelism = [
2532+
# self.ici_diloco_parallelism,
2533+
# self.ici_pipeline_parallelism,
2534+
# self.ici_data_parallelism,
2535+
# self.ici_fsdp_parallelism,
2536+
# self.ici_fsdp_transpose_parallelism,
2537+
# self.ici_sequence_parallelism,
2538+
# self.ici_context_parallelism,
2539+
# self.ici_context_autoregressive_parallelism,
2540+
# self.ici_tensor_parallelism,
2541+
# self.ici_tensor_transpose_parallelism,
2542+
# self.ici_tensor_sequence_parallelism,
2543+
# self.ici_expert_parallelism,
2544+
# self.ici_autoregressive_parallelism,
2545+
# ]
2546+
# self.dcn_parallelism = [
2547+
# self.dcn_diloco_parallelism,
2548+
# self.dcn_pipeline_parallelism,
2549+
# self.dcn_data_parallelism,
2550+
# self.dcn_fsdp_parallelism,
2551+
# self.dcn_fsdp_transpose_parallelism,
2552+
# self.dcn_sequence_parallelism,
2553+
# self.dcn_context_parallelism,
2554+
# self.dcn_context_autoregressive_parallelism,
2555+
# self.dcn_tensor_parallelism,
2556+
# self.dcn_tensor_transpose_parallelism,
2557+
# self.dcn_tensor_sequence_parallelism,
2558+
# self.dcn_expert_parallelism,
2559+
# self.dcn_autoregressive_parallelism,
2560+
# ]
2561+
# else:
2562+
ici_map = {
2563+
"diloco": self.ici_diloco_parallelism,
2564+
"data": self.ici_data_parallelism,
2565+
"stage": self.ici_pipeline_parallelism,
2566+
"fsdp": self.ici_fsdp_parallelism,
2567+
"fsdp_transpose": self.ici_fsdp_transpose_parallelism,
2568+
"sequence": self.ici_sequence_parallelism,
2569+
"context": self.ici_context_parallelism,
2570+
"context_autoregressive": self.ici_context_autoregressive_parallelism,
2571+
"tensor": self.ici_tensor_parallelism,
2572+
"tensor_transpose": self.ici_tensor_transpose_parallelism,
2573+
"tensor_sequence": self.ici_tensor_sequence_parallelism,
2574+
"model": self.ici_tensor_parallelism,
2575+
"expert": self.ici_expert_parallelism,
2576+
"autoregressive": self.ici_autoregressive_parallelism,
2577+
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
25782578
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
2579-
}
2580-
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
2581-
2582-
dcn_map = {
2583-
"diloco": self.dcn_diloco_parallelism,
2584-
"data": self.dcn_data_parallelism,
2585-
"stage": self.dcn_pipeline_parallelism,
2586-
"fsdp": self.dcn_fsdp_parallelism,
2587-
"fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
2588-
"sequence": self.dcn_sequence_parallelism,
2589-
"context": self.dcn_context_parallelism,
2590-
"context_autoregressive": self.dcn_context_autoregressive_parallelism,
2591-
"tensor": self.dcn_tensor_parallelism,
2592-
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
2593-
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
2594-
"model": self.dcn_tensor_parallelism,
2595-
"expert": self.dcn_expert_parallelism,
2596-
"autoregressive": self.dcn_autoregressive_parallelism,
2597-
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2579+
}
2580+
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
2581+
2582+
dcn_map = {
2583+
"diloco": self.dcn_diloco_parallelism,
2584+
"data": self.dcn_data_parallelism,
2585+
"stage": self.dcn_pipeline_parallelism,
2586+
"fsdp": self.dcn_fsdp_parallelism,
2587+
"fsdp_transpose": self.dcn_fsdp_transpose_parallelism,
2588+
"sequence": self.dcn_sequence_parallelism,
2589+
"context": self.dcn_context_parallelism,
2590+
"context_autoregressive": self.dcn_context_autoregressive_parallelism,
2591+
"tensor": self.dcn_tensor_parallelism,
2592+
"tensor_transpose": self.dcn_tensor_transpose_parallelism,
2593+
"tensor_sequence": self.dcn_tensor_sequence_parallelism,
2594+
"model": self.dcn_tensor_parallelism,
2595+
"expert": self.dcn_expert_parallelism,
2596+
"autoregressive": self.dcn_autoregressive_parallelism,
2597+
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
25982598
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
2599-
}
2600-
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
2599+
}
2600+
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
26012601

26022602
# Diloco params
26032603
self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)

src/maxtext/layers/attention_op.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,9 @@ def maybe_create_nnx(einsum, *args):
580580

581581
def _logical_to_mesh_axes(self, logical_name):
582582
logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules
583-
return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=logical_rules)
583+
return logical_to_mesh_axes(
584+
logical_name, mesh=self.mesh, rules=logical_rules
585+
)
584586

585587
def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None:
586588
"""Check attention inputs."""

src/maxtext/layers/decoders.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,11 +1085,11 @@ def __call__(
10851085
else:
10861086
logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
10871087
logits = sharding.maybe_shard_with_logical(
1088-
logits,
1089-
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
1090-
mesh=self.mesh,
1091-
shard_mode=self.config.shard_mode,
1092-
debug_sharding=self.config.debug_sharding,
1088+
logits,
1089+
("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"),
1090+
mesh=self.mesh,
1091+
shard_mode=self.config.shard_mode,
1092+
debug_sharding=self.config.debug_sharding,
10931093
)
10941094

10951095
# The API of the Decoder is now a tuple, providing both the main output

0 commit comments

Comments (0)