Skip to content

Commit d252ddf

Browse files
mesakhcienetGoogle-ML-Automation
authored andcommitted
Copybara import of the project:
-- b6dc761 by mesakhcienet <mesakh.christian@cienet.com>: feat: migrate deepseek to nnx COPYBARA_INTEGRATE_REVIEW=#2658 from CIeNET-International:feat/migrate-deepseek-to-nnx b6dc761 PiperOrigin-RevId: 852094911
1 parent 98a3d4c commit d252ddf

5 files changed

Lines changed: 677 additions & 428 deletions

File tree

src/MaxText/layers/decoders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,9 +404,9 @@ def get_decoder_layers(self):
404404
return [mixtral.MixtralDecoderLayerToLinen]
405405
case DecoderBlockType.DEEPSEEK:
406406
if self.config.use_batch_split_schedule:
407-
return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer]
407+
return [deepseek_batchsplit.DeepSeekDenseLayerToLinen, deepseek_batchsplit.DeepSeekMoELayerToLinen]
408408
else:
409-
return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
409+
return [deepseek.DeepSeekDenseLayerToLinen, deepseek.DeepSeekMoELayerToLinen]
410410
case DecoderBlockType.GEMMA:
411411
return [gemma.GemmaDecoderLayerToLinen]
412412
case DecoderBlockType.GEMMA2:

0 commit comments

Comments
 (0)