Skip to content

Commit e8cbb57

Browse files
Merge pull request #3016 from AI-Hypercomputer:chengnuojin-fix-balance-mtp
PiperOrigin-RevId: 861940470
2 parents 9bb5959 + 5b867e3 commit e8cbb57

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

src/MaxText/layers/deepseek.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -387,9 +387,10 @@ def __call__(
     return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache)
 
   def mlp_op(self, x, deterministic, *args, **kwargs):
-    return self.with_logical_constraint(
-        self.DeepSeekMoeBlock_0(x, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding)
+    mlp_lnx, load_balance_loss, moe_bias_updates = self.DeepSeekMoeBlock_0(
+        x, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding
     )
+    return self.with_logical_constraint(mlp_lnx), load_balance_loss, moe_bias_updates
 
 
 DeepSeekMoELayerToLinen = nnx_wrappers.to_linen_class(

0 commit comments

Comments
 (0)