Skip to content

Commit df0487f

Browse files
Merge pull request #3429 from CIeNET-International:charlesli/fix_nnx_error
PiperOrigin-RevId: 885115210
2 parents 93e2feb + 8b7b89e commit df0487f

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxtext/models/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ def __init__(
335335
if cfg.pure_nnx_decoder:
336336
self.decoder = NNXDecoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs)
337337
else:
338-
self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
339-
self.decoder = nnx_wrappers.ToNNX(self.decoder, rngs=rngs)
338+
decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
339+
self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs)
340340
self.hidden_states = None
341341

342342
batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode)
@@ -529,7 +529,7 @@ def __call__(
529529
attention_metadata=attention_metadata,
530530
deepstack_visual_embeds=deepstack_visual_embeds,
531531
mutable=mutable_collections,
532-
) # pytype: disable=wrong-keyword-args
532+
) # pytype: disable=wrong-keyword-args
533533

534534
# Materialize hidden state when vocab tiling is enabled
535535
if self.config.num_vocab_tiling > 1:

0 commit comments

Comments
 (0)