
Commit 102af23

Merge pull request #3141 from AI-Hypercomputer:mohit/attn_expert_submit

PiperOrigin-RevId: 879321135

2 parents: d14f70d + 0b7666f

7 files changed: 76 additions & 36 deletions

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -148,3 +148,6 @@ dmypy.json
 # Gemini CLI
 .gemini/
 gha-creds-*.json
+
+# vscode workspace
+maxtext.code-workspace

src/maxtext/configs/inference/vllm.yml

Lines changed: 12 additions & 11 deletions

@@ -25,7 +25,7 @@ weight_dtype: bfloat16
 
 
 # -------------- Logical Axis Rules --------------
-mesh_axes: ['data', 'attn_dp', 'model', 'expert']
+mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
 logical_axis_rules: [
   ['activation_batch', ['expert']],
   ['activation_batch_no_exp', []],
@@ -37,37 +37,38 @@ logical_axis_rules: [
   ['activation_attn_length_no_exp', []],
   ['activation_length', ['data', 'expert']],
   ['activation_length_no_exp', 'data'],
-  ['activation_q_length', ['expert']],
+  ['activation_q_length', ['expert', 'attn_dp_expert']],
   ['activation_attn_embed', 'model'],
   ['activation_embed', ['model', 'attn_dp']],
   ['activation_mlp', ['model', 'attn_dp']],
   ['activation_kv', ['model']],
-  ['activation_prefill_kv_batch', ['expert']],
-  ['activation_kv_batch', ['expert']],
+  ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
+  ['activation_kv_batch', ['expert', 'attn_dp_expert']],
   ['activation_kv_batch_no_exp', []],
   ['activation_kv_head_dim', ['model']],
   ['activation_vocab', ['model', 'attn_dp']],
   ['activation_norm_length', []],
-  ['activation_exp', ['expert']],
-  ['decode_batch', ['expert']],
+  ['activation_exp', ['expert', 'attn_dp_expert']],
+  ['decode_batch', ['expert', 'attn_dp_expert']],
   ['decode_length', []],
   ['mlp', ['model', 'attn_dp']],
   ['mlp_no_fsdp', ['model', 'attn_dp']],
+  ['moe_mlp', ['model', 'attn_dp']],
   ['vocab', ['model', 'attn_dp']],
   ['heads', ['model']],
   ['q_heads', ['model']],
   ['kv_heads', ['model']],
   ['kv_head_dim', []],
   ['kv', []],
-  ['embed', ['expert']],
+  ['embed', ['expert', 'attn_dp_expert']],
   ['embed_tensor_transpose', ['attn_dp', 'model']],
   ['embed_no_exp', []],
-  ['q_lora', ['expert']],
-  ['kv_lora', ['expert']],
+  ['q_lora', ['expert', 'attn_dp_expert']],
+  ['kv_lora', ['expert', 'attn_dp_expert']],
   ['norm', []],
   ['cache_heads', ['model']],
-  ['exp', ['expert']],
+  ['exp', ['expert', 'attn_dp_expert']],
   ['paged_kv_heads', ['model']],
 ]
-data_sharding: [['data', 'attn_dp', 'model', 'expert']]
+data_sharding: [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
 input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
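
Each entry in logical_axis_rules maps a logical tensor dimension to one or more mesh axes, and the new attn_dp_expert axis now appears wherever expert-parallel sharding applies. Below is a minimal sketch of how such a rule table resolves to a JAX PartitionSpec; the mesh sizes and the to_pspec helper are hypothetical, for illustration only (MaxText has its own resolver):

# A minimal sketch of how a logical-to-mesh rule resolves to a PartitionSpec.
# Assumes 4 devices; axis sizes and the to_pspec helper are hypothetical,
# for illustration only.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()[:4]).reshape(1, 1, 2, 1, 2)
mesh = Mesh(devices, axis_names=("data", "attn_dp", "model", "expert", "attn_dp_expert"))

# A few rules copied from vllm.yml: logical dimension name -> mesh axes.
rules = {
    "activation_q_length": ("expert", "attn_dp_expert"),
    "embed": ("expert", "attn_dp_expert"),
    "heads": ("model",),
}

def to_pspec(*logical_names):
  # Unknown logical names fall back to None, i.e. replicated.
  return PartitionSpec(*(rules.get(name) for name in logical_names))

sharding = NamedSharding(mesh, to_pspec("activation_q_length", "heads"))
print(sharding.spec)  # PartitionSpec(('expert', 'attn_dp_expert'), ('model',))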

src/maxtext/configs/post_train/rl.yml

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ num_samplers_slices: -1
 # replicas in rollout. If not specified, rollout_tensor_parallelism will be auto-determined.
 rollout_data_parallelism: -1
 rollout_tensor_parallelism: -1
+rollout_expert_parallelism: 1
 
 # ====== Reproducibility ======
 data_shuffle_seed: 42
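
The three rollout knobs must multiply to the number of sampler devices, with -1 meaning auto-derived from the other two (see the train_rl.py change below). A hypothetical override:

# Hypothetical rl.yml override: on 8 sampler devices, pin TP and EP and let
# DP be auto-derived (8 / (2 * 2) = 2), since tp * dp * ep must equal the
# sampler device count.
rollout_tensor_parallelism: 2
rollout_expert_parallelism: 2
rollout_data_parallelism: -1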

src/maxtext/configs/types.py

Lines changed: 3 additions & 0 deletions

@@ -1585,6 +1585,7 @@ class RLHardware(BaseModel):
       -1,
       description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined.",
   )
+  rollout_expert_parallelism: int = Field(1, description="Expert parallelism per replica for rollout")
 
 
 class VLLM(BaseModel):
@@ -2573,6 +2574,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         "expert": self.ici_expert_parallelism,
         "autoregressive": self.ici_autoregressive_parallelism,
         "attn_dp": 1,  # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
+        "attn_dp_expert": 1,  # initialized to 1, vLLM will auto calculate this value based on EP
     }
     self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
 
@@ -2592,6 +2594,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         "expert": self.dcn_expert_parallelism,
         "autoregressive": self.dcn_autoregressive_parallelism,
         "attn_dp": 1,  # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
+        "attn_dp_expert": 1,  # initialized to 1, vLLM will auto calculate this value based on EP
     }
     self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
 
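The ici/dcn maps are indexed by mesh_axes to build the parallelism vectors, so any axis named in mesh_axes must have an entry; the new attn_dp_expert entry defaults to 1 and is later recalculated by vLLM. A toy reproduction (sizes hypothetical; the real maps carry more axes):

# Toy reproduction of how ici_parallelism is assembled from mesh_axes.
# Sizes are hypothetical; in MaxText they come from config fields, and the
# real map carries more axes (fsdp, sequence, autoregressive, ...).
mesh_axes = ["data", "attn_dp", "model", "expert", "attn_dp_expert"]
ici_map = {
    "data": 1,
    "attn_dp": 1,         # placeholder; vLLM recalculates from TP and num_kv_heads
    "model": 4,
    "expert": 2,
    "attn_dp_expert": 1,  # placeholder; vLLM recalculates from EP
}
# Any axis named in mesh_axes without a map entry would raise a KeyError.
ici_parallelism = [ici_map[axis] for axis in mesh_axes]
print(ici_parallelism)  # [1, 1, 4, 2, 1]
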
src/maxtext/inference/vllm_decode.py

Lines changed: 1 addition & 0 deletions

@@ -75,6 +75,7 @@ def decode_with_vllm(config: Config) -> None:
       "hf_config_path": config.vllm_hf_config_path,
       "hf_overrides": config.vllm_hf_overrides,
       "gpu_memory_utilization": config.hbm_utilization_vllm,
+      "async_scheduling": config.async_scheduling,
       "additional_config": {
           "maxtext_config": {
               "model_name": config.model_name,

src/maxtext/layers/moe.py

Lines changed: 18 additions & 13 deletions

@@ -368,6 +368,11 @@ def __init__(
     else:
       self._tensor_parallelism_name = "tensor"
 
+    if self.config.attention == "vllm_rpa":
+      self._expert_parallelism_name = "attn_dp_expert"
+    else:
+      self._expert_parallelism_name = "expert"
+
     self.gate = GateLogit(
         in_features_shape=self.config.emb_dim,
         out_features_shape=self.num_experts,
@@ -467,7 +472,7 @@ def _logical_to_mesh_axes(self, logical_name):
     return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=logical_rules)
 
   def get_expert_parallelism_size(self):
-    return self.mesh.shape.get("expert", 1)
+    return self.mesh.shape.get(self._expert_parallelism_name, 1)
 
   def get_tensor_parallelism_size(self):
     if isinstance(self._tensor_parallelism_name, tuple):
@@ -494,8 +499,8 @@ def get_topk(self, gate_logits, pre_bias_logits, rngs=None):
     if self.config.use_random_routing:
       if rngs is None:
         raise ValueError("The random key cannot be None for random routing.")
-      # Reuse the 'dropout' RNG stream to ensure random routing
-      rng = rngs.dropout()
+      # Reuse the 'params' RNG stream to ensure random routing
+      rng = rngs.params()
       top_k_weights, top_k_indices = random_routing(rng, gate_logits, self.num_experts_per_tok)
       return top_k_weights, top_k_indices
 
@@ -1002,7 +1007,7 @@ def gmm(
     # batch_size=1 while decode can have batch_size > 1.
     try:
       is_batch_sharded_by_expert = (
-          "expert"
+          self._expert_parallelism_name
           in tuple(
               filter(
                   lambda tup: tup[0] == "activation_batch",
@@ -1094,10 +1099,9 @@ def gmm(
     )
     def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
       batch_size, sequence_length, _ = x.shape
-      expert_axis_name = "expert"
       num_expert_parallelism = self.get_expert_parallelism_size()
       if num_expert_parallelism > 1:
-        expert_shard_id = jax.lax.axis_index(expert_axis_name)
+        expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name)
       else:
         expert_shard_id = 0
       num_expert_parallelism = self.get_expert_parallelism_size()
@@ -1107,7 +1111,8 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
 
       # Duplicate inputs to all expert shards.
       x, logits, pre_bias_logits = tuple(
-          jax.lax.all_gather(z, axis_name=expert_axis_name, tiled=True) for z in (x, logits, pre_bias_logits)
+          jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True)
+          for z in (x, logits, pre_bias_logits)
       )
 
       # "Route" tokens within each shard.
@@ -1131,7 +1136,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       )
 
       if num_expert_parallelism > 1:
-        batch_axis = "expert" if is_batch_sharded_by_expert else "data"
+        batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
         # get group sizes for all shards
         local_expert_size = self.config.num_experts // num_expert_parallelism
         reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
@@ -1163,9 +1168,9 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
               send_sizes,
               output_offsets,
               recv_sizes,
-              axis_name=expert_axis_name,
+              axis_name=self._expert_parallelism_name,
           )
-          global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=expert_axis_name)
+          global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
           x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
               x,
               global_group_sizes,
@@ -1310,7 +1315,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
 
         # Sum up the partial outputs across the expert shards.
         output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim))
-        output = jax.lax.psum_scatter(output, expert_axis_name, scatter_dimension=0, tiled=True)
+        output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)
 
       else:
         if num_expert_parallelism > 1:
@@ -1343,7 +1348,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
                 send_sizes,
                 output_offsets,
                 recv_sizes,
-                axis_name=expert_axis_name,
+                axis_name=self._expert_parallelism_name,
             )
           else:
             # If batch is replicated across EP shards then each shard should send
@@ -1363,7 +1368,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
                 send_sizes,
                 output_offsets,
                 recv_sizes,
-                axis_name=expert_axis_name,
+                axis_name=self._expert_parallelism_name,
             )
 
       output = self.unpermute(
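
The common thread in these hunks: every collective that used the literal axis name "expert" now takes self._expert_parallelism_name, so the same code path runs over attn_dp_expert when attention is vllm_rpa. A minimal shard_map sketch of an axis-name-parameterized psum_scatter (two devices assumed; the mesh and function names are illustrative, not MaxText's):

# Minimal sketch: the same collective runs over whichever named mesh axis is
# passed in, which is what parameterizing self._expert_parallelism_name buys.
# Assumes 2 devices; names here are illustrative, not MaxText's.
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), axis_names=("attn_dp_expert",))

def sum_partials(axis_name, partial_out):
  # Each shard holds a partial expert output; psum_scatter sums the partials
  # and leaves every shard with its own slice of the result.
  return jax.lax.psum_scatter(partial_out, axis_name, scatter_dimension=0, tiled=True)

f = shard_map(
    partial(sum_partials, "attn_dp_expert"),
    mesh=mesh,
    in_specs=P(),                   # input replicated across the expert axis
    out_specs=P("attn_dp_expert"),  # output scattered along it
)
print(f(jnp.ones((4, 8))).shape)  # (4, 8) globally, (2, 8) per shard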

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 38 additions & 12 deletions

@@ -228,31 +228,56 @@ def setup_configs_and_devices(argv: list[str]):
   return trainer_config, sampler_config, trainer_devices, sampler_devices
 
 
-def get_rollout_kwargs_for_data_parallelism(sampler_config, num_sampler_devices):
+def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices):
   """Get rollout kwargs for vLLM rollout when using data parallelism."""
   dp = sampler_config.rollout_data_parallelism
-  if dp == -1:
-    return {}
-
-  rollout_kwargs = {}
   tp = sampler_config.rollout_tensor_parallelism
+  ep = sampler_config.rollout_expert_parallelism
+
+  # -1 means "auto-derive from the other two". At most one can be -1.
+  num_auto = sum(1 for x in [tp, dp, ep] if x == -1)
+  if num_auto > 1:
+    raise ValueError(
+        "At most one of rollout_tensor_parallelism, rollout_data_parallelism, "
+        "rollout_expert_parallelism can be -1 (auto-derived)."
+    )
 
-  if tp == -1:
-    if num_sampler_devices % dp != 0:
+  if dp == -1:
+    if num_sampler_devices % (tp * ep) != 0:
       raise ValueError(
           f"num_sampler_devices({num_sampler_devices}) must be divisible by "
-          f"rollout_data_parallelism({dp}) "
+          f"rollout_tensor_parallelism({tp}) * rollout_expert_parallelism({ep}) "
+          f"when rollout_data_parallelism is -1."
+      )
+    dp = num_sampler_devices // tp // ep
+  elif tp == -1:
+    if num_sampler_devices % (dp * ep) != 0:
+      raise ValueError(
+          f"num_sampler_devices({num_sampler_devices}) must be divisible by "
+          f"rollout_data_parallelism({dp}) * rollout_expert_parallelism({ep}) "
           f"when rollout_tensor_parallelism is -1."
       )
-    tp = num_sampler_devices // dp
-  elif tp * dp != num_sampler_devices:
+    tp = num_sampler_devices // dp // ep
+  elif ep == -1:
+    if num_sampler_devices % (tp * dp) != 0:
+      raise ValueError(
+          f"num_sampler_devices({num_sampler_devices}) must be divisible by "
+          f"rollout_tensor_parallelism({tp}) * rollout_data_parallelism({dp}) "
+          f"when rollout_expert_parallelism is -1."
+      )
+    ep = num_sampler_devices // tp // dp
+  elif tp * dp * ep != num_sampler_devices:
     raise ValueError(
         f"rollout_tensor_parallelism({tp}) * "
-        f"rollout_data_parallelism({dp}) "
+        f"rollout_data_parallelism({dp}) * "
+        f"rollout_expert_parallelism({ep}) "
        f"!= len(sampler_devices)({num_sampler_devices})"
    )
+
+  rollout_kwargs = {}
   rollout_kwargs["tensor_parallel_size"] = tp
   rollout_kwargs["data_parallel_size"] = dp
+  rollout_kwargs["expert_parallel_size"] = ep
 
   return rollout_kwargs
 
@@ -544,13 +569,14 @@ def _filter_long_prompts(x):
         rollout_vllm_async_scheduling=trainer_config.async_scheduling,
         rollout_vllm_kwargs={
             "hf_overrides": trainer_config.vllm_hf_overrides,
+            "enable_expert_parallel": sampler_config.rollout_expert_parallelism > 1,
         },
         rollout_vllm_sampling_kwargs={
             "stop": trainer_config.stop_strings,
            "detokenize": trainer_config.stop_strings is not None,
            "include_stop_str_in_output": trainer_config.stop_strings is not None,
         },
-        **get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
+        **get_rollout_kwargs_for_parallelism(sampler_config, len(sampler_devices)),
     ),
 )
 grpo_config = GrpoConfig(
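
A standalone sketch of the new tp/dp/ep derivation, simplified from the diff above (plain ints instead of config objects, error messages collapsed):

# Standalone sketch of the tp/dp/ep derivation, simplified from the diff.
def derive_rollout_parallelism(tp, dp, ep, num_devices):
  # At most one of the three may be -1 (auto-derived from the other two).
  if sum(x == -1 for x in (tp, dp, ep)) > 1:
    raise ValueError("At most one of tp, dp, ep can be -1.")
  if dp == -1:
    dp = num_devices // (tp * ep)
  elif tp == -1:
    tp = num_devices // (dp * ep)
  elif ep == -1:
    ep = num_devices // (tp * dp)
  # Floor division above makes this final check catch non-divisible cases too.
  if tp * dp * ep != num_devices:
    raise ValueError(f"tp({tp}) * dp({dp}) * ep({ep}) != {num_devices}")
  return {"tensor_parallel_size": tp, "data_parallel_size": dp, "expert_parallel_size": ep}

print(derive_rollout_parallelism(tp=2, dp=-1, ep=2, num_devices=8))
# {'tensor_parallel_size': 2, 'data_parallel_size': 2, 'expert_parallel_size': 2}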
