1919import jax
2020import jax .numpy as jnp
2121from jax .sharding import Mesh
22- from flax import linen as nn
22+ from flax import nnx
2323
2424from MaxText .common_types import Config
2525from MaxText import max_logging , pyconfig
2929from MaxText .layers import multi_token_prediction # The class under test
3030from MaxText .layers import embeddings
3131from MaxText .common_types import MODEL_MODE_TRAIN
32-
32+ from MaxText . layers import nnx_wrappers
3333
3434TEST_LAYER_NUM = 1
3535
@@ -122,29 +122,35 @@ def test_multi_token_prediction_layer_output(self):
122122
123123
124124# A lightweight wrapper model for robustly testing the MTPBlock.
125- class MTPBlockTestModel (nn .Module ):
125+ class MTPBlockTestModel (nnx .Module ):
126126 """A lightweight wrapper model for testing the MTPBlock."""
127127
128- config : Config
129- mesh : Mesh
130-
131- def setup (self ):
128+ def __init__ (
129+ self ,
130+ config : Config ,
131+ mesh : Mesh ,
132+ rngs : nnx .Rngs | None = None ,
133+ ):
134+ self .config = config
135+ self .mesh = mesh
132136 """Initializes the MTP block and its dependencies for the test."""
133- self .shared_embedding = embeddings .embed_as_linen (
134- mesh = self .mesh ,
137+ self .shared_embedding = embeddings .Embed (
135138 num_embeddings = self .config .vocab_size ,
136139 num_features = self .config .base_emb_dim ,
137140 config = self .config ,
138- name = "shared_embedding" ,
141+ mesh = self .mesh ,
142+ rngs = rngs ,
139143 )
140- self .decoder = Decoder (config = self .config , mesh = self .mesh , name = "decoder_for_mtp" )
141- self .mtp_block = multi_token_prediction .MultiTokenPredictionBlock (
144+ decoder_for_mtp = Decoder (config = self .config , mesh = self .mesh , name = "decoder_for_mtp" )
145+
146+ self .multi_token_prediction_block = multi_token_prediction .MultiTokenPredictionBlock (
142147 config = self .config ,
143148 mesh = self .mesh ,
144149 name = "mtp_block" ,
145150 transformer_layer_module = DecoderLayer ,
146- decoder = self . decoder ,
151+ decoder = decoder_for_mtp ,
147152 )
153+ self .mtp_block = nnx_wrappers .ToNNX (self .multi_token_prediction_block , rngs = nnx .Rngs (params = 0 ))
148154
149155 def __call__ (
150156 self ,
@@ -156,6 +162,7 @@ def __call__(
156162 decoder_segment_ids ,
157163 model_mode ,
158164 deterministic ,
165+ mutable = None ,
159166 ):
160167 return self .mtp_block (
161168 self .shared_embedding ,
@@ -167,6 +174,7 @@ def __call__(
167174 decoder_segment_ids ,
168175 model_mode ,
169176 deterministic ,
177+ mutable = mutable ,
170178 )
171179
172180
@@ -181,6 +189,7 @@ def setUp(self):
181189 skip_jax_distributed_system = True ,
182190 mtp_num_layers = 2 ,
183191 )
192+ self .nnx_rngs = nnx .Rngs (params = 0 )
184193 self .rng = jax .random .PRNGKey (43 )
185194 devices_array = maxtext_utils .create_device_mesh (self .cfg )
186195 self .mesh = Mesh (devices_array , self .cfg .mesh_axes )
@@ -195,23 +204,11 @@ def setUp(self):
195204 self .position_ids = jnp .arange (self .seq_len , dtype = jnp .int32 ).reshape (1 , - 1 )
196205 self .decoder_segment_ids = jnp .ones ((self .batch_size , self .seq_len ), dtype = jnp .int32 )
197206
198- self .test_model = MTPBlockTestModel (config = self .cfg , mesh = self .mesh )
199- self .variables = self .test_model .init (
200- {"params" : self .init_rng , "dropout" : self .init_rng },
201- self .main_hidden_state ,
202- self .input_ids ,
203- self .target_ids ,
204- self .target_mask ,
205- self .position_ids ,
206- self .decoder_segment_ids ,
207- model_mode = MODEL_MODE_TRAIN ,
208- deterministic = True ,
209- )
207+ self .test_model = MTPBlockTestModel (config = self .cfg , mesh = self .mesh , rngs = self .nnx_rngs )
210208
211209 def test_sow_functionality (self ):
212210 """Verifies that the block correctly sows losses and weights."""
213- _ , captured_vars = self .test_model .apply (
214- self .variables ,
211+ self .test_model (
215212 self .main_hidden_state ,
216213 self .input_ids ,
217214 self .target_ids ,
@@ -222,25 +219,24 @@ def test_sow_functionality(self):
222219 model_mode = MODEL_MODE_TRAIN ,
223220 mutable = ["mtp_losses" ],
224221 )
225- self .assertIn ( "mtp_losses" , captured_vars )
226- sown_data = maxtext_utils . get_nested_value ( captured_vars , ( "mtp_losses" , " mtp_block" ), {})
227- self .assertIn ( "losses" , sown_data )
228- self .assertEqual (len (sown_data [ "losses" ] ), self .cfg .mtp_num_layers )
222+ self .assertTrue ( hasattr ( self . test_model . mtp_block , "losses" ) )
223+ mtp_loss = self . test_model . mtp_block . losses
224+ self .assertTrue ( type ( mtp_loss ). __name__ , "mtp_losses" )
225+ self .assertEqual (len (mtp_loss ), self .cfg .mtp_num_layers )
229226
230227 def test_no_sow_during_init (self ):
231228 """Verifies no losses are sown during model initialization."""
232229 # `self.variables` was created by `.init()`. We inspect it to ensure
233230 # our `if not self.is_initializing()` check worked.
234- self .assertNotIn ( "mtp_losses" , self .variables )
231+ self .assertFalse ( hasattr ( self .test_model . mtp_block , "losses" ) )
235232
236233 def test_loss_aggregation_logic (self ):
237234 """
238235 Tests the full 'sow and reap' cycle, mimicking the logic from train.py
239236 to ensure the final loss calculation is correct.
240237 """
241238 # 1. Run the forward pass and capture the sown variables.
242- _ , captured_vars = self .test_model .apply (
243- self .variables ,
239+ self .test_model (
244240 self .main_hidden_state ,
245241 self .input_ids ,
246242 self .target_ids ,
@@ -250,26 +246,21 @@ def test_loss_aggregation_logic(self):
250246 deterministic = False ,
251247 mutable = ["mtp_losses" ],
252248 model_mode = MODEL_MODE_TRAIN ,
253- rngs = {"dropout" : self .rng },
254249 )
255250
256251 # This section of the test now *becomes* the logic from train.py
257252 # -------------------------------------------------------------
258253 final_loss_for_gradient = 100.0 # A dummy main loss
259254 mtp_loss_for_logging = 0.0
260255
261- # 2. Define the exact path to retrieve the sown variables.
262- losses_path = ("mtp_losses" , "mtp_block" , "losses" )
263- weights_path = ("mtp_losses" , "mtp_block" , "weights" )
264-
265- # 3. Use the standard utility to get the data.
266- mtp_losses = maxtext_utils .get_nested_value (captured_vars , losses_path , default = ())
267- mtp_weights = maxtext_utils .get_nested_value (captured_vars , weights_path , default = ())
256+ # 2. Get the weight and losses.
257+ mtp_losses = self .test_model .mtp_block .losses .value
258+ mtp_weights = self .test_model .mtp_block .weights .value
268259
269- # 4 . Perform the aggregation logic exactly as in `loss_fn`.
260+ # 3 . Perform the aggregation logic exactly as in `loss_fn`.
270261 if mtp_losses :
271- sum_of_all_mtp_losses = jnp .sum (jnp .array (mtp_losses ))
272- sum_of_all_mtp_weights = jnp .sum (jnp .array (mtp_weights ))
262+ sum_of_all_mtp_losses = jnp .sum (jnp .array (mtp_losses )). item ()
263+ sum_of_all_mtp_weights = jnp .sum (jnp .array (mtp_weights )). item ()
273264
274265 self .assertGreater (sum_of_all_mtp_weights , 0 )
275266
@@ -280,7 +271,7 @@ def test_loss_aggregation_logic(self):
280271 mtp_loss_for_logging = scaled_mtp_loss
281272 # -------------------------------------------------------------
282273
283- # 5 . Assert that the final values are correct.
274+ # 4 . Assert that the final values are correct.
284275 # The final loss should have increased from its base value.
285276 self .assertGreater (final_loss_for_gradient , 100.0 )
286277 # The logged MTP loss should be a valid, positive number.
0 commit comments