Skip to content

Commit 5b867e3

Browse files
committed
fix load balance sharding error
1 parent 941d46a commit 5b867e3

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

src/MaxText/layers/deepseek.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,10 @@ def __call__(
     return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache)

   def mlp_op(self, x, deterministic, *args, **kwargs):
-    return self.with_logical_constraint(
-        self.DeepSeekMoeBlock_0(x, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding)
+    mlp_lnx, load_balance_loss, moe_bias_updates = self.DeepSeekMoeBlock_0(
+        x, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding
     )
+    return self.with_logical_constraint(mlp_lnx), load_balance_loss, moe_bias_updates


 DeepSeekMoELayerToLinen = nnx_wrappers.to_linen_class(

0 commit comments

Comments (0)