 from jax import ad_checkpoint as adc
 from jax.experimental import xla_metadata
 from jax.sharding import NamedSharding, Mesh
+from jax.sharding import PartitionSpec as P
 import jax.numpy as jnp
 from MaxText import common_types as ctypes
 from MaxText import max_logging
@@ -133,6 +134,30 @@ def random_routing(rng_key, gate_logits, num_experts_per_tok):
   return top_k_weights, top_k_indices


+def calculate_load_balance_updates(top_k_indices, num_experts, rate):
+  """Computes a bias adjustment update based on expert load.
+
+  Used in DeepSeek V3: https://arxiv.org/html/2412.19437v1.
+  Implementation reference: https://arxiv.org/pdf/2408.15664.
+
+  Args:
+    top_k_indices: Shape (batch, sequence, top_k).
+    num_experts: Total number of experts.
+    rate: The update rate.
+
+  Returns:
+    update: The value to add to the expert bias. Shape (num_experts,).
+  """
+  flat_indices = top_k_indices.ravel()
+  expert_counts = jnp.bincount(flat_indices, length=num_experts)
+
+  total_tokens = flat_indices.size
+  average_load = total_tokens / num_experts
+  # Nudge under-loaded experts up and over-loaded experts down by a fixed rate.
+  direction = jnp.sign(average_load - expert_counts)
+  output = direction * rate
+  return output
+
+
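For intuition, here is a small worked example of the function above (the values are illustrative, not from the change): with four experts and six routed slots, the overloaded expert's bias is pushed down and the under-loaded experts' biases are pushed up, while jnp.sign leaves exactly-balanced experts untouched.

import jax.numpy as jnp

# (batch=1, seq=3, top_k=2): expert 0 receives 3 of 6 slots, experts 1-3 one each.
top_k_indices = jnp.array([[[0, 0], [0, 1], [2, 3]]])
update = calculate_load_balance_updates(top_k_indices, num_experts=4, rate=0.001)
# counts = [3, 1, 1, 1], average load = 6 / 4 = 1.5
# update == [-0.001, +0.001, +0.001, +0.001]
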
 class GateLogit(nnx.Module):
   """A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing."""

@@ -436,6 +461,10 @@ def get_tensor_transpose_parallelism_size(self):
   def get_context_autoregressive_parallelism_size(self):
     return self.mesh.shape.get("context_autoregressive", 1)

+  def should_update_load_balance(self):
+    """Determines if loss-free load balancing updates should be applied."""
+    return self.config.routed_bias and self.config.routed_bias_update_rate > 0.0
+
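For context on where should_update_load_balance fits: in auxiliary-loss-free balancing (per the two papers referenced above), the accumulated bias influences only which experts are selected, not how their outputs are weighted. Below is a minimal sketch of that selection step, with a hypothetical per-expert bias buffer; the actual wiring through GateLogit and get_topk is not shown in this diff.

import jax
import jax.numpy as jnp

def biased_topk(pre_bias_logits, bias, top_k):
  """Pick experts from biased scores; weight them with the unbiased ones."""
  # The bias steers token placement toward under-loaded experts...
  _, top_k_indices = jax.lax.top_k(pre_bias_logits + bias, top_k)
  # ...but combine weights come from the original scores, so the bias never
  # directly changes the layer's output values.
  top_k_weights = jnp.take_along_axis(pre_bias_logits, top_k_indices, axis=-1)
  return top_k_weights, top_k_indices
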
   def get_topk(self, gate_logits, pre_bias_logits, rngs=None):
     """get topk."""
     # shape of top_k_weights & top_k_indices:
@@ -560,6 +589,18 @@ def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True
     inputs_2d = jnp.reshape(inputs, (bsz_times_seq_len, inputs_shape[2]))
     weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits, rngs)

+    lb_loss = None
+    if self.config.load_balance_loss_weight > 0.0:
+      softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
+      lb_loss = self.load_balance_loss(selected_experts, softmax_probs)
+
+    if self.should_update_load_balance():
+      bias_updates = calculate_load_balance_updates(
+          selected_experts, self.config.num_experts, self.config.routed_bias_update_rate
+      )
+    else:
+      bias_updates = None
+
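The body of load_balance_loss is not part of this diff; for reference, MoE stacks commonly use the Switch-Transformer-style auxiliary loss, sketched below under that assumption (shapes follow the tensors in permute).

import jax
import jax.numpy as jnp

def switch_style_aux_loss(selected_experts, softmax_probs, num_experts):
  """num_experts * sum_i f_i * P_i, minimized when routing is uniform."""
  # f_i: mean per-token count of top-k assignments dispatched to expert i.
  one_hot = jax.nn.one_hot(selected_experts, num_experts)  # (b, s, k, e)
  f = jnp.mean(jnp.sum(one_hot, axis=-2), axis=(0, 1))     # (e,)
  # P_i: mean router probability assigned to expert i.
  p = jnp.mean(softmax_probs, axis=(0, 1))                 # (e,)
  return num_experts * jnp.sum(f * p)
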
     if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
       # weights will be of shape (batch_size, seq_len, num_experts_per_tok)
       router_scores = jax.nn.sigmoid(weights.astype(jnp.float32))  # weights are top_k_weights here
@@ -589,6 +630,8 @@ def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True
         weights,
         group_size,
         sorted_experts,
+        lb_loss,
+        bias_updates,
     )

   def unpermute(
@@ -1010,9 +1053,13 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
             w0_bias_pspec,
             w1_bias_pspec,
             wo_bias_pspec,
-            None,
+            P(),  # Replicate the input rngs key
+        ),
+        out_specs=(
+            self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
+            P(),  # lb_loss: replicated across the mesh (or None)
+            P(),  # bias_updates: replicated across the mesh (or None)
         ),
-        out_specs=(self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed"))),
         check_vma=False,
     )

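The out_specs change follows from shard_map's rule that out_specs must mirror the output pytree of the wrapped function, one PartitionSpec per leaf; since wrapper now returns three values, three specs are needed, with P() replicating the auxiliary outputs. A self-contained toy, assuming the jax.experimental.shard_map API (reductions must happen inside so a P() output is genuinely replicated):

import functools
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ("data",))

@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=P("data", None),
    out_specs=(P("data", None), P()),  # one spec per output leaf
)
def f(x):
  # Reduce across shards so the P() output is identical on every device.
  aux = jax.lax.psum(jnp.sum(x), "data")
  return 2 * x, aux
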
@@ -1035,7 +1082,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r

         # "Route" tokens within each shard.
         num_experts_per_shard = self.config.num_experts // num_expert_parallelism
-        x, sorted_selected_experts, weights, group_sizes, selected_experts = self.permute(
+        x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute(
             x,
             logits,
             pre_bias_logits,
@@ -1049,7 +1096,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
         mask = jnp.arange(x.shape[0]) < jnp.sum(group_sizes)
         x = jnp.where(mask[:, None], x, 0)
       else:
-        x, sorted_selected_experts, weights, group_sizes, selected_experts = self.permute(
+        x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute(
             x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs
         )

@@ -1264,7 +1311,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
           use_custom_sort_vjp=self.config.use_custom_sort_vjp,
       )

-      return output, None
+      return output, lb_loss, bias_updates

     if self.config.moe_fsdp_use_two_stage_all_gather:
       # Unshard on fsdp axis
@@ -1563,7 +1610,7 @@ def dense_matmul(
       w0_bias,
       w1_bias,
       wo_bias,
-  ) -> tuple[jax.Array, Optional[jax.Array]]:
+  ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
     """Dense matrix multiplication."""
     # gate_logits: batch, length, expert
     gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None))
@@ -1581,11 +1628,23 @@ def dense_matmul(
     weights = self.reshape_and_update_weights(top_k_weights, top_k_indices)
     matmul_precision = jax.lax.Precision(self.config.matmul_precision)

+    # Calculate load balance loss
     if self.config.model_call_mode != "inference":
       softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
-      loss = self.load_balance_loss(top_k_indices, softmax_probs)
+      lb_loss = (
+          self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None
+      )
     else:
-      loss = None
+      lb_loss = None
+
+    # Calculate routed bias updates (loss-free)
+    if self.should_update_load_balance():
+      bias_updates = calculate_load_balance_updates(
+          top_k_indices, self.config.num_experts, self.config.routed_bias_update_rate
+      )
+    else:
+      bias_updates = None
+
     batch_size = inputs.shape[0]
     seq_len = inputs.shape[1]

@@ -1783,7 +1842,7 @@ def dense_matmul(
               output.shape[3],
           ),
       )
-      return output, loss
+      return output, lb_loss, bias_updates
     else:
       inputs = self._maybe_shard_with_logical(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
       with jax.named_scope("wi_0"):
@@ -1831,7 +1890,7 @@ def dense_matmul(
           weights,
           precision=matmul_precision,
       ).astype(self.dtype)
-      return output, None
+      return output, lb_loss, bias_updates

   def retrieve_quantized_weight(
       self,
@@ -1864,7 +1923,7 @@ def retrieve_quantized_weight(

   def __call__(
       self, inputs: jax.Array, out_sharding: NamedSharding | None = None
-  ) -> tuple[jax.Array, Optional[jax.Array]]:
+  ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
     cfg = self.config
     inputs = inputs.astype(cfg.dtype)
     gate_logits, pre_bias_logits = self.gate(inputs)
@@ -1893,13 +1952,14 @@ def __call__(
           w1_bias,
           wo_bias,
       )
-      return self.sparse_matmul(
+      output, lb_loss, bias_updates = self.sparse_matmul(
           inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
       )
     else:
-      return self.dense_matmul(
+      output, lb_loss, bias_updates = self.dense_matmul(
           inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
       )
+    return output, lb_loss, bias_updates
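With __call__ now returning a triple, here is a hedged sketch of how a training loop might consume it; all names below (moe_layer, gate_bias, task_loss) stand in for wiring outside this diff. The auxiliary loss joins the objective through the gradient, while the bias update is applied as a plain state update, since loss-free balancing adjusts the bias by load sign rather than by backprop.

output, lb_loss, bias_updates = moe_layer(hidden_states)

total_loss = task_loss
if lb_loss is not None:
  # Gradient path: weighted auxiliary loss.
  total_loss = total_loss + config.load_balance_loss_weight * lb_loss

if bias_updates is not None:
  # Non-gradient path: nudge the routing bias buffer directly.
  gate_bias = gate_bias + jax.lax.stop_gradient(bias_updates)
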


 class RoutedAndSharedMoE(nnx.Module):
@@ -1916,7 +1976,7 @@ def __init__(
       dtype: ctypes.DType = jnp.float32,
       quant: Optional[quantizations.AqtQuantization] = None,
   ):
-    """nitializes the RoutedAndSharedMoE module.
+    """Initializes the RoutedAndSharedMoE module.

     Attributes:
       config: The main config setting.
@@ -1973,10 +2033,10 @@ def __call__(
       inputs: jax.Array,
       intermediate_sharding: NamedSharding | None = None,
       out_sharding: NamedSharding | None = None,
-  ) -> jax.Array:
-    routed_experts, _ = self.routed_moe(inputs, out_sharding=out_sharding)
+  ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
+    routed_experts, load_balance_loss, moe_bias_updates = self.routed_moe(inputs, out_sharding=out_sharding)
     shared_experts = self.shared_experts(inputs, intermediate_sharding=intermediate_sharding, out_sharding=out_sharding)
-    return routed_experts + shared_experts
+    return routed_experts + shared_experts, load_balance_loss, moe_bias_updates
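Note that this changes RoutedAndSharedMoE's public return type, so call sites that previously bound a single array must now unpack three values (the decoder-layer names below are assumptions):

# before this change:
#   hidden = self.moe_block(hidden)
# after:
hidden, load_balance_loss, moe_bias_updates = self.moe_block(hidden)
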


 def get_gate_logit(