
Commit 46218db

Update loss free & loss control load balance
1 parent 08216c6 commit 46218db

13 files changed

Lines changed: 215 additions & 43 deletions


docs/reference/core_concepts/moe_configuration.md

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,8 @@ Dropping:
 
 `routed_bias`: If enabled, adds a learnable bias term to the gate logits to facilitate load balancing.
 
+`routed_bias_update_rate`: Defines the update rate for the routed bias term above. Applicable only to the DeepSeek decoder block.
+
 `routed_score_func`: Defines the scoring function for the router.
 
 `routed_scaling_factor`: A scalar multiplier applied to the expert weights.
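
To make the two options concrete: in DeepSeek-V3-style loss-free balancing, the bias shifts only which experts are *selected*, while the combine weights still come from the pre-bias scores, and `routed_bias_update_rate` controls how fast the bias is nudged (see `calculate_load_balance_updates` in `moe.py` below). A minimal sketch, not MaxText code; the function and variable names here are hypothetical:

```python
# Illustrative sketch of bias-adjusted top-k routing (DeepSeek-V3 style).
import jax
import jax.numpy as jnp

def biased_top_k(scores, bias, k):
  """scores: (tokens, num_experts); bias: (num_experts,)."""
  # The bias only influences which experts are selected...
  _, top_k_idx = jax.lax.top_k(scores + bias, k)
  # ...while the combine weights use the original, pre-bias scores.
  top_k_weights = jnp.take_along_axis(scores, top_k_idx, axis=-1)
  return top_k_weights, top_k_idx

scores = jnp.array([[0.3, 0.2, 0.1, 0.4]])
bias = jnp.array([0.0, 0.0, 0.25, -0.25])  # nudge load toward expert 2, away from 3
print(biased_top_k(scores, bias, k=2))     # selects experts 2 and 0
```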

src/MaxText/configs/base.yml

Lines changed: 2 additions & 1 deletion
@@ -177,7 +177,7 @@ num_experts_per_tok: 1
 megablox: true
 sparse_matmul: true
 capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
-load_balance_loss_weight: 0.01 # weight for the load balance loss
+load_balance_loss_weight: 0.0 # weight for the load balance loss
 use_random_routing: false # whether to use random routing for debug/test purpose
 use_custom_sort_vjp: true # whether to use a custom sort vjp for sparse matmul ops
 use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
@@ -224,6 +224,7 @@ shared_experts: 1
 routed_scaling_factor: 1.0 # scaling factor for routing scores
 routed_score_func: "" # scoring function for routing
 routed_bias: False # a flag if a learnable bias is added for routing
+routed_bias_update_rate: 0.0 # the update rate applied to the router bias term
 mlp_bias: False # a flag if a learnable bias is added for MLP matmul
 n_routing_groups: -1 # number of groups for routing, disabled by default
 topk_routing_group: -1 # number of top groups to route inputs. For EP,

src/MaxText/configs/types.py

Lines changed: 4 additions & 1 deletion
@@ -552,7 +552,7 @@ class MoEGeneral(BaseModel):
   num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
   num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
   capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
-  load_balance_loss_weight: NonNegativeFloat = Field(0.01, description="Weight for the load balancing auxiliary loss.")
+  load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
   use_custom_sort_vjp: bool = Field(True, description="Whether to use a custom sort VJP for sparse matmul ops.")
   use_ring_of_experts: bool = Field(
       False,
@@ -639,6 +639,7 @@ class DeepSeekMoE(BaseModel):
   routed_scaling_factor: float = Field(1.0, description="Scaling factor for routing scores.")
   routed_score_func: str = Field("", description="Scoring function for routing (e.g., 'softmax', 'sigmoid').")
   routed_bias: bool = Field(False, description="Whether to add a bias term for routing.")
+  routed_bias_update_rate: float = Field(0.0, description="Update rate applied to the router bias term.")
   mlp_bias: bool = Field(False, description="Whether to add a learnable bias for MLP matmul.")
   n_routing_groups: int = Field(-1, description="Number of groups for routing, disabled by default.")
   topk_routing_group: int = Field(-1, description="Number of top groups to route inputs to.")
@@ -2043,6 +2044,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     )
     if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1:
       raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.")
+    if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block != DecoderBlockType.DEEPSEEK:
+      raise ValueError("Loss-free load balancing is only supported for the DeepSeek decoder block.")
     if self.use_multimodal:
       valid_mm_models = (
           "gemma3-4b",

src/MaxText/layers/deepseek.py

Lines changed: 10 additions & 4 deletions
@@ -138,8 +138,14 @@ def self_attention_with_norm(
   return hidden_states, intermediate_inputs
 
 
-def post_process(cfg, layer_output, sow, kv_cache=None):
+def post_process(cfg, layer_output, load_balance_loss, moe_bias_updates, sow, kv_cache=None):
   """postprocessing."""
+  if cfg.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
+    sow("intermediates", "moe_lb_loss", load_balance_loss)
+
+  if cfg.routed_bias and cfg.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
+    sow("intermediates", "moe_bias_updates", moe_bias_updates)
+
   if cfg.record_internal_nn_metrics:
     sow("intermediates", "activation_mean", jnp.mean(layer_output))
     sow("intermediates", "activation_stdev", jnp.std(layer_output))
@@ -233,7 +239,7 @@ def __call__(
         layer_output,
         logical_axis_names,
     )
-    return post_process(cfg, layer_output, self.sow)
+    return post_process(cfg, layer_output, None, None, self.sow)
 
 
 class DeepSeekMoELayer(nn.Module):
@@ -296,7 +302,7 @@ def __call__(
     # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints.
    # The `name` represents the weight name in JAX/checkpoints and so the class name
    # is just for readability.
-    mlp_lnx = moe.get_routed_and_shared_moe(
+    mlp_lnx, load_balance_loss, moe_bias_updates = moe.get_routed_and_shared_moe(
        name="DeepSeekMoeBlock_0",
        config=cfg,
        mesh=self.mesh,
@@ -314,4 +320,4 @@ def __call__(
        layer_output,
        logical_axis_names,
    )
-    return post_process(cfg, layer_output, self.sow)
+    return post_process(cfg, layer_output, load_balance_loss, moe_bias_updates, self.sow)
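
For readers unfamiliar with `sow`: values written to the "intermediates" collection surface when the module is applied with that collection marked mutable. A self-contained toy using the standard Flax linen API (hypothetical module, not MaxText code):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

class TinyMoeLike(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Stand-in for post_process: record an auxiliary value for the trainer.
    self.sow("intermediates", "moe_lb_loss", jnp.mean(x**2))
    return x

m = TinyMoeLike()
x = jnp.ones((2, 3))
variables = m.init(jax.random.PRNGKey(0), x)
y, state = m.apply(variables, x, mutable=["intermediates"])
print(state["intermediates"]["moe_lb_loss"])  # tuple of sown values: (Array(1.0),)
```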

src/MaxText/layers/gpt_oss.py

Lines changed: 2 additions & 2 deletions
@@ -182,7 +182,7 @@ def __call__(
     )
 
     load_balance_loss = None
-    mlp_lnx, load_balance_loss = self.GptOssMlp(hidden_states)
+    mlp_lnx, load_balance_loss, _ = self.GptOssMlp(hidden_states)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
 
     layer_output = mlp_lnx + intermediate_inputs
@@ -193,7 +193,7 @@ def __call__(
         ("activation_batch", "activation_norm_length", "activation_embed"),
     )
 
-    if load_balance_loss is not None:
+    if cfg.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
       self.sow("intermediates", "moe_lb_loss", load_balance_loss)
 
     if cfg.record_internal_nn_metrics:

src/MaxText/layers/llama4.py

Lines changed: 5 additions & 1 deletion
@@ -484,8 +484,9 @@ def __call__(
     hidden_states = self.post_self_attention_layer_norm(intermediate_inputs)
     hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names)
 
+    load_balance_loss = None
     if self.is_moe_layer:
-      mlp_lnx = self.moe_block(hidden_states)
+      mlp_lnx, load_balance_loss, _ = self.moe_block(hidden_states)
     else:
       mlp_lnx = self.mlp(hidden_states, deterministic=deterministic)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names)
@@ -494,6 +495,9 @@ def __call__(
     layer_output = self.dropout(layer_output, deterministic=deterministic)
     layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
 
+    if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
+      self.sow("intermediates", "moe_lb_loss", load_balance_loss)
+
     if cfg.record_internal_nn_metrics:
       self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
       self.sow("intermediates", "activation_stdev", jnp.std(layer_output))

src/MaxText/layers/mixtral.py

Lines changed: 2 additions & 2 deletions
@@ -172,14 +172,14 @@ def __call__(
     # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints.
     # The `name` represents the weight name in JAX/checkpoints and so the class name
     # is just for readability.
-    mlp_lnx, load_balance_loss = self.MoeBlock_0(hidden_states)
+    mlp_lnx, load_balance_loss, _ = self.MoeBlock_0(hidden_states)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names)
 
     layer_output = mlp_lnx + intermediate_inputs
     layer_output = self.dropout(layer_output, deterministic=deterministic)
     layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
 
-    if load_balance_loss is not None:
+    if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
       self.sow("intermediates", "moe_lb_loss", load_balance_loss)
 
     if self.config.record_internal_nn_metrics:

src/MaxText/layers/moe.py

Lines changed: 77 additions & 17 deletions
@@ -27,6 +27,7 @@
 from jax import ad_checkpoint as adc
 from jax.experimental import xla_metadata
 from jax.sharding import NamedSharding, Mesh
+from jax.sharding import PartitionSpec as P
 import jax.numpy as jnp
 from MaxText import common_types as ctypes
 from MaxText import max_logging
@@ -133,6 +134,30 @@ def random_routing(rng_key, gate_logits, num_experts_per_tok):
   return top_k_weights, top_k_indices
 
 
+def calculate_load_balance_updates(top_k_indices, num_experts, rate):
+  """
+  Computes a bias adjustment update based on expert load.
+  Used in DeepSeek V3: https://arxiv.org/html/2412.19437v1.
+  Implementation reference: https://arxiv.org/pdf/2408.15664.
+
+  Args:
+    top_k_indices: Shape (batch, sequence, top_k).
+    num_experts: Total number of experts.
+    rate: The update rate.
+
+  Returns:
+    update: The value to add to the expert bias. Shape (num_experts,).
+  """
+  flat_indices = top_k_indices.ravel()
+  expert_counts = jnp.bincount(flat_indices, length=num_experts)
+
+  total_tokens = flat_indices.size
+  average_load = total_tokens / num_experts
+  direction = jnp.sign(average_load - expert_counts)
+  output = direction * rate
+  return output
+
+
 class GateLogit(nnx.Module):
   """A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing."""
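
The helper is easy to sanity-check in isolation. A runnable toy that mirrors the function above (the input values are illustrative):

```python
import jax.numpy as jnp

def calculate_load_balance_updates(top_k_indices, num_experts, rate):
  flat_indices = top_k_indices.ravel()
  expert_counts = jnp.bincount(flat_indices, length=num_experts)
  average_load = flat_indices.size / num_experts
  return jnp.sign(average_load - expert_counts) * rate

# 3 tokens, top-2 routing over 4 experts; experts 0/1 are hot, 3 is idle.
idx = jnp.array([[[0, 1], [0, 1], [0, 2]]])
print(calculate_load_balance_updates(idx, num_experts=4, rate=0.001))
# [-0.001 -0.001  0.001  0.001]: overloaded experts are biased down,
# underloaded ones up, by exactly `routed_bias_update_rate` per step.
```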

@@ -436,6 +461,10 @@ def get_tensor_transpose_parallelism_size(self):
   def get_context_autoregressive_parallelism_size(self):
     return self.mesh.shape.get("context_autoregressive", 1)
 
+  def should_update_load_balance(self):
+    """Determines if loss-free load balancing updates should be applied."""
+    return self.config.routed_bias and self.config.routed_bias_update_rate > 0.0
+
   def get_topk(self, gate_logits, pre_bias_logits, rngs=None):
     """get topk."""
     # shape of top_k_weights & top_k_indices:
@@ -560,6 +589,18 @@ def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True
     inputs_2d = jnp.reshape(inputs, (bsz_times_seq_len, inputs_shape[2]))
     weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits, rngs)
 
+    lb_loss = None
+    if self.config.load_balance_loss_weight > 0.0:
+      softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
+      lb_loss = self.load_balance_loss(selected_experts, softmax_probs)
+
+    if self.should_update_load_balance():
+      bias_updates = calculate_load_balance_updates(
+          selected_experts, self.config.num_experts, self.config.routed_bias_update_rate
+      )
+    else:
+      bias_updates = None
+
     if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
       # weights will be of shape (batch_size, seq_len, num_experts_per_tok)
       router_scores = jax.nn.sigmoid(weights.astype(jnp.float32))  # weights are top_k_weights here
@@ -589,6 +630,8 @@ def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True
         weights,
         group_size,
         sorted_experts,
+        lb_loss,
+        bias_updates,
     )
 
   def unpermute(
@@ -1010,9 +1053,13 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
             w0_bias_pspec,
             w1_bias_pspec,
             wo_bias_pspec,
-            None,
+            P(),  # Replicate the input key
+        ),
+        out_specs=(
+            self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
+            P(),  # Handle None or replicate the output
+            P(),  # Handle None or replicate the output
         ),
-        out_specs=(self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed"))),
         check_vma=False,
     )
     def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
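
Context for the `out_specs` change: `shard_map` needs one `PartitionSpec` per output, so once the mapped function returns `lb_loss` and `bias_updates` alongside the activations, the single output spec must become a tuple, with `P()` marking the extra outputs as replicated. A minimal standalone illustration using the recent `jax.shard_map` API (not MaxText code; names are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ("data",))

def f(x):
  y = x * 2
  # psum makes the scalar identical on every shard, so the replicated
  # spec P() is a valid out_spec for it.
  aux = jax.lax.psum(jnp.sum(x), "data")
  return y, aux

g = jax.shard_map(f, mesh=mesh, in_specs=(P("data"),), out_specs=(P("data"), P()))
y, aux = g(jnp.arange(8.0))
print(y.shape, aux)  # (8,) 28.0
```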
@@ -1035,7 +1082,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
 
       # "Route" tokens within each shard.
       num_experts_per_shard = self.config.num_experts // num_expert_parallelism
-      x, sorted_selected_experts, weights, group_sizes, selected_experts = self.permute(
+      x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute(
          x,
          logits,
          pre_bias_logits,
@@ -1049,7 +1096,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       mask = jnp.arange(x.shape[0]) < jnp.sum(group_sizes)
       x = jnp.where(mask[:, None], x, 0)
     else:
-      x, sorted_selected_experts, weights, group_sizes, selected_experts = self.permute(
+      x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute(
          x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs
      )
 
@@ -1264,7 +1311,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
         use_custom_sort_vjp=self.config.use_custom_sort_vjp,
     )
 
-    return output, None
+    return output, lb_loss, bias_updates
 
   if self.config.moe_fsdp_use_two_stage_all_gather:
     # Unshard on fsdp axis
@@ -1563,7 +1610,7 @@ def dense_matmul(
       w0_bias,
       w1_bias,
       wo_bias,
-  ) -> tuple[jax.Array, Optional[jax.Array]]:
+  ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
     """Dense matrix multiplication."""
     # gate_logits: batch, length, expert
     gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None))
@@ -1581,11 +1628,23 @@ def dense_matmul(
     weights = self.reshape_and_update_weights(top_k_weights, top_k_indices)
     matmul_precision = jax.lax.Precision(self.config.matmul_precision)
 
+    # Calculate load balance loss
     if self.config.model_call_mode != "inference":
       softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
-      loss = self.load_balance_loss(top_k_indices, softmax_probs)
+      lb_loss = (
+          self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None
+      )
     else:
-      loss = None
+      lb_loss = None
+
+    # Calculate routed bias updates (loss-free)
+    if self.should_update_load_balance():
+      bias_updates = calculate_load_balance_updates(
+          top_k_indices, self.config.num_experts, self.config.routed_bias_update_rate
+      )
+    else:
+      bias_updates = None
+
     batch_size = inputs.shape[0]
     seq_len = inputs.shape[1]
 
@@ -1783,7 +1842,7 @@ def dense_matmul(
               output.shape[3],
           ),
       )
-      return output, loss
+      return output, lb_loss, bias_updates
     else:
       inputs = self._maybe_shard_with_logical(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
       with jax.named_scope("wi_0"):
@@ -1831,7 +1890,7 @@ def dense_matmul(
           weights,
           precision=matmul_precision,
       ).astype(self.dtype)
-      return output, None
+      return output, lb_loss, bias_updates
 
   def retrieve_quantized_weight(
       self,
@@ -1864,7 +1923,7 @@ def retrieve_quantized_weight(
 
   def __call__(
       self, inputs: jax.Array, out_sharding: NamedSharding | None = None
-  ) -> tuple[jax.Array, Optional[jax.Array]]:
+  ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
     cfg = self.config
     inputs = inputs.astype(cfg.dtype)
     gate_logits, pre_bias_logits = self.gate(inputs)
@@ -1893,13 +1952,14 @@ def __call__(
           w1_bias,
           wo_bias,
       )
-      return self.sparse_matmul(
+      output, lb_loss, bias_updates = self.sparse_matmul(
          inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
      )
     else:
-      return self.dense_matmul(
+      output, lb_loss, bias_updates = self.dense_matmul(
          inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
      )
+    return output, lb_loss, bias_updates
 
 
 class RoutedAndSharedMoE(nnx.Module):
@@ -1916,7 +1976,7 @@ def __init__(
       dtype: ctypes.DType = jnp.float32,
       quant: Optional[quantizations.AqtQuantization] = None,
   ):
-    """nitializes the RoutedAndSharedMoE module.
+    """Initializes the RoutedAndSharedMoE module.
 
     Attributes:
       config: The main config setting.
@@ -1973,10 +2033,10 @@ def __call__(
       inputs: jax.Array,
       intermediate_sharding: NamedSharding | None = None,
       out_sharding: NamedSharding | None = None,
-  ) -> jax.Array:
-    routed_experts, _ = self.routed_moe(inputs, out_sharding=out_sharding)
+  ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
+    routed_experts, load_balance_loss, moe_bias_updates = self.routed_moe(inputs, out_sharding=out_sharding)
     shared_experts = self.shared_experts(inputs, intermediate_sharding=intermediate_sharding, out_sharding=out_sharding)
-    return routed_experts + shared_experts
+    return routed_experts + shared_experts, load_balance_loss, moe_bias_updates
 
 
 def get_gate_logit(
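
Taken together, these changes plumb the loss-free signal from the router out through the layer stack to whatever consumes the sown `moe_bias_updates`. A self-contained simulation of the update rule in action (illustrative only, not this commit's trainer code):

```python
import jax
import jax.numpy as jnp

num_experts, k, rate = 4, 2, 0.01
scores = jax.random.normal(jax.random.PRNGKey(0), (256, num_experts))  # fixed token affinities
bias = jnp.zeros(num_experts)
for _ in range(200):
  # Selection uses biased scores, as in DeepSeek-V3 routing.
  _, top_k_idx = jax.lax.top_k(scores + bias, k)
  counts = jnp.bincount(top_k_idx.ravel(), length=num_experts)
  # Same rule as calculate_load_balance_updates above.
  bias = bias + jnp.sign(top_k_idx.size / num_experts - counts) * rate
print(counts)  # per-expert load approaches the uniform average (128 each)
```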
