
Commit e986f24

Update a sharding config
1 parent 4799cef commit e986f24

6 files changed

Lines changed: 15 additions & 15 deletions


docs/reference/core_concepts/moe_configuration.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -97,7 +97,7 @@ Dropping:
 
 `moe_fsdp_use_two_stage_all_gather`: If enabled, split the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.
 
-`fsdp_shard_on_exp`: If enabled, shard the expert dimension of the MLP weights on the FSDP axis, and recommended when num_experts is a multiple of fsdp_parallelism.
+`shard_exp_on_fsdp`: If enabled, shard the expert dimension of the MLP weights on the FSDP axis, and recommended only when num_experts is a multiple of fsdp_parallelism.
 
 ## 3. Performance Tuning
 These parameters provide granular control over the tiling dimensions for sparse matmul Pallas kernel.
```
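
For concreteness, a small arithmetic sketch of the divisibility recommendation; the values below are illustrative, not from the commit:

```python
# With the expert dimension sharded on the FSDP axis, every shard owns
# num_experts / fsdp_parallelism whole experts, so the division must be exact.
num_experts = 256
fsdp_parallelism = 8
assert num_experts % fsdp_parallelism == 0
experts_per_shard = num_experts // fsdp_parallelism  # 32 experts per shard
```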

src/MaxText/configs/base.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -214,7 +214,7 @@ expert_shard_attention_option: "fsdp"
 moe_fsdp_use_two_stage_all_gather: false
 # Shard the expert dimension of the MLP weights on the FSDP axis.
 # This configuration is recommended only when num_experts is a multiple of fsdp_parallelism
-fsdp_shard_on_exp: False
+shard_exp_on_fsdp: False
 # use fsdp and fsdp_transpose axes for sharding the moe weights
 use_2d_fsdp_sharding: False
```

src/MaxText/configs/types.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -570,10 +570,10 @@ class MoEGeneral(BaseModel):
       False,
       description="Use two separate All-Gather calls for MoE weights sharded on both FSDP and FSDP-transpose.",
   )
-  fsdp_shard_on_exp: bool = Field(
+  shard_exp_on_fsdp: bool = Field(
       False,
       description="Shard the expert dimension of the MLP weights on the FSDP axis, "
-      "and recommended when num_experts is a multiple of fsdp_parallelism",
+      "and recommended only when num_experts is a multiple of fsdp_parallelism",
   )
   use_2d_fsdp_sharding: bool = Field(
       False,
```
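
As a minimal standalone sketch (not MaxText's full `MoEGeneral` model), the renamed flag behaves like any boolean pydantic field with a `False` default:

```python
from pydantic import BaseModel, Field

class MoEGeneralSketch(BaseModel):
  """Illustrative subset of the config model; only the renamed field is shown."""
  shard_exp_on_fsdp: bool = Field(
      False,
      description="Shard the expert dimension of the MLP weights on the FSDP axis, "
      "and recommended only when num_experts is a multiple of fsdp_parallelism",
  )

print(MoEGeneralSketch().shard_exp_on_fsdp)                         # False
print(MoEGeneralSketch(shard_exp_on_fsdp=True).shard_exp_on_fsdp)   # True
```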

src/MaxText/layers/moe.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -342,7 +342,7 @@ def __init__(
     self.quant = quant
     self.rngs = rngs
 
-    if self.config.fsdp_shard_on_exp:
+    if self.config.shard_exp_on_fsdp:
       # special sharding for dsv3
       self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
       self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
@@ -1012,10 +1012,10 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
   # w0, w1, wo needs to be un sharded on fsdp / fsdp_transpose axis, so use
   # mlp_no_fsdp axis
   weight_gather = False
-  if self.config.fsdp_shard_on_exp:
+  if self.config.shard_exp_on_fsdp:
     quantization_rule = qpl.get_current_rule("gmm")
     if quantization_rule and quantization_rule.weight_calibration_method.startswith("fixed"):
-      # special sharding when using static scaling for weights in quantization with fsdp_shard_on_exp
+      # special sharding when using static scaling for weights in quantization with shard_exp_on_fsdp
       w0_pspec = self._logical_to_mesh_axes(self.wi_kernel_axes)
       w1_pspec = self._logical_to_mesh_axes(self.wi_kernel_axes)
       wo_pspec = self._logical_to_mesh_axes(self.wo_kernel_axes)
```
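
A minimal JAX sketch of the sharding choice the flag selects, assuming a hypothetical 1D mesh and made-up shapes (this is not MaxText's actual helper code): the expert (leading) dimension of an MoE kernel is placed on the `fsdp` mesh axis instead of the embed dimension.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1D mesh whose single axis plays the role of MaxText's 'fsdp' axis.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("fsdp",))

num_experts = 8 * jax.device_count()    # must be a multiple of the fsdp axis size
emb, mlp = 4 * jax.device_count(), 32   # illustrative sizes
w_i = jnp.zeros((num_experts, emb, mlp))  # (expert, embed, mlp) MoE kernel

# Default FSDP sharding: split the embed dimension across the fsdp axis.
w_default = jax.device_put(w_i, NamedSharding(mesh, P(None, "fsdp", None)))
# shard_exp_on_fsdp: split the expert dimension instead; embed stays whole.
w_exp = jax.device_put(w_i, NamedSharding(mesh, P("fsdp", None, None)))
print(w_exp.sharding)
```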

src/MaxText/pyconfig_deprecated.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -305,7 +305,7 @@ def validate_keys(keys):
   validate_mlp_dim(keys)
   validate_sparse_matmul_parallelism(keys)
   validate_ring_of_experts_parallelism(keys)
-  validate_shard_fsdp_on_expert_parallelism(keys)
+  validate_shard_expert_on_fsdp(keys)
   validate_ragged_dot(keys)
   validate_deepseek_moe(keys)
   validate_gpt_oss_moe(keys)
@@ -1212,12 +1212,12 @@ def validate_ring_of_experts_parallelism(raw_keys):
     raise ValueError("Ring-of-experts requires expert-parallelism to be enabled.")
 
 
-def validate_shard_fsdp_on_expert_parallelism(raw_keys):
-  if raw_keys["fsdp_shard_on_exp"] and raw_keys["num_experts"] % raw_keys["ici_fsdp_parallelism"] != 0:
-    raise ValueError("fsdp_shard_on_exp requires num_experts to be divisible by ici_fsdp_parallelism.")
-  if raw_keys["fsdp_shard_on_exp"] and (using_tensor_parallelism(raw_keys) or using_expert_parallelism(raw_keys)):
+def validate_shard_expert_on_fsdp(raw_keys):
+  if raw_keys["shard_exp_on_fsdp"] and raw_keys["num_experts"] % raw_keys["ici_fsdp_parallelism"] != 0:
+    raise ValueError("shard_exp_on_fsdp requires num_experts to be divisible by ici_fsdp_parallelism.")
+  if raw_keys["shard_exp_on_fsdp"] and (using_tensor_parallelism(raw_keys) or using_expert_parallelism(raw_keys)):
     raise ValueError(
-        "fsdp_shard_on_exp requires ici_expert_parallelism = 1 and ici_tensor_parallelism/ici_tensor_transpose_parallelism = 1."
+        "shard_exp_on_fsdp requires ici_expert_parallelism = 1 and ici_tensor_parallelism/ici_tensor_transpose_parallelism = 1."
     )
```
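
A runnable sketch of exercising the divisibility rule above with hand-made `raw_keys` dicts; only the relevant keys are filled in, and the tensor/expert-parallelism check is omitted because it depends on helpers not shown here.

```python
def validate_shard_expert_on_fsdp(raw_keys):
  # Same divisibility rule as the validator above, in isolation.
  if raw_keys["shard_exp_on_fsdp"] and raw_keys["num_experts"] % raw_keys["ici_fsdp_parallelism"] != 0:
    raise ValueError("shard_exp_on_fsdp requires num_experts to be divisible by ici_fsdp_parallelism.")

validate_shard_expert_on_fsdp({"shard_exp_on_fsdp": True, "num_experts": 256, "ici_fsdp_parallelism": 8})  # passes
try:
  validate_shard_expert_on_fsdp({"shard_exp_on_fsdp": True, "num_experts": 6, "ici_fsdp_parallelism": 4})
except ValueError as err:
  print(err)  # 6 is not divisible by 4
```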

tests/check_qwen3_next_vs_reference.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -648,7 +648,7 @@ def setUp(self):
         "num_experts_per_tok=2",
         "base_moe_mlp_dim=256",  # moe_mlp_dim will be calculated from this
         "norm_topk_prob=True",
-        "fsdp_shard_on_exp=False",
+        "shard_exp_on_fsdp=False",
         "mlp_activations=['silu', 'linear']",
         "dropout_rate=0.0",
         # Force the test to use the 'dense_matmul' path in the MoE layer,
@@ -1103,7 +1103,7 @@ def _run_full_attention_jax_vs_pytorch_attention(self, attention_type):
         "num_experts_per_tok=2",
         "base_moe_mlp_dim=256",  # moe_mlp_dim will be calculated from this
         "norm_topk_prob=True",
-        "fsdp_shard_on_exp=False",
+        "shard_exp_on_fsdp=False",
         "mlp_activations=['silu', 'linear']",
         "dropout_rate=0.0",
         "sparse_matmul=False",
```
