@@ -128,6 +128,40 @@ def __init__(
128128 model_mode = MODEL_MODE_TRAIN ,
129129 )
130130
131+
132+ @property
133+ def embedding_norm (self ):
134+ return getattr (self , f"mtp_{ self .layer_number } _embedding_norm" )
135+
136+ @embedding_norm .setter
137+ def embedding_norm (self , module ):
138+ setattr (self , f"mtp_{ self .layer_number } _embedding_norm" , module )
139+
140+ @property
141+ def hidden_state_norm (self ):
142+ return getattr (self , f"mtp_{ self .layer_number } _hidden_state_norm" )
143+
144+ @hidden_state_norm .setter
145+ def hidden_state_norm (self , module ):
146+ setattr (self , f"mtp_{ self .layer_number } _hidden_state_norm" , module )
147+
148+ @property
149+ def projection_layer (self ):
150+ return getattr (self , f"mtp_{ self .layer_number } _projection" )
151+
152+ @projection_layer .setter
153+ def projection_layer (self , module ):
154+ setattr (self , f"mtp_{ self .layer_number } _projection" , module )
155+
156+ @property
157+ def transformer_layer (self ):
158+ return getattr (self , f"mtp_{ self .layer_number } _transformer_layer" )
159+
160+ @transformer_layer .setter
161+ def transformer_layer (self , module ):
162+ setattr (self , f"mtp_{ self .layer_number } _transformer_layer" , module )
163+
164+
131165 def __call__ (
132166 self ,
133167 prev_hidden_state : jnp .ndarray ,
@@ -192,13 +226,6 @@ def __init__(
192226 self .decoder = decoder
193227 self .rngs = rngs if rngs is not None else nnx .Rngs (0 )
194228
195- # NNX Variables are exposed as Linen mutable collections by ToLinen wrapper.
196- self .losses = mtp_losses (jnp .zeros ((config .mtp_num_layers ,), dtype = jnp .float32 ))
197- self .weights = mtp_losses (jnp .zeros ((config .mtp_num_layers ,), dtype = jnp .float32 ))
198- # Float32 used to avoid gradient errors; converted to int32 in acceptance rate calculation.
199- self .mtp_preds = mtp_acceptance (jnp .zeros ((1 ,), dtype = jnp .float32 ))
200- self .mtp_mask = mtp_acceptance (jnp .zeros ((1 ,), dtype = jnp .float32 ))
201-
202229 # 1-indexed to match paper convention.
203230 for k in range (1 , config .mtp_num_layers + 1 ):
204231 layer = MultiTokenPredictionLayer (
@@ -278,11 +305,13 @@ def __call__(
278305 mtp_masks_list .append (rolled_target_mask )
279306
280307 if mtp_losses_list :
281- self .losses .value = jnp .stack (mtp_losses_list )
282- self .weights .value = jnp .stack (mtp_weights_list )
308+ # Not part of checkpoints, don't declare in __init__
309+ self .losses = mtp_losses (jnp .stack (mtp_losses_list ))
310+ self .weights = mtp_losses (jnp .stack (mtp_weights_list ))
283311 if mtp_preds_list :
284- self .mtp_preds .value = jnp .stack (mtp_preds_list )
285- self .mtp_mask .value = jnp .stack (mtp_masks_list )
312+ # Not part of checkpoints, don't declare in __init__
313+ self .mtp_preds = mtp_acceptance (jnp .stack (mtp_preds_list ))
314+ self .mtp_mask = mtp_acceptance (jnp .stack (mtp_masks_list ))
286315
287316 return {}
288317
0 commit comments