
Commit 1a50f57

Migrate test modules to NNX
1 parent ed517cf commit 1a50f57

3 files changed

Lines changed: 139 additions & 83 deletions

tests/maxtext_utils_test.py

Lines changed: 39 additions & 17 deletions
@@ -27,8 +27,8 @@
 
 from flax import linen as nn
 from flax.core.scope import FrozenVariableDict
-from flax.linen import Dense
 from flax.training import train_state
+from flax import nnx
 
 import optax
 
@@ -156,39 +156,61 @@ def test_init_training_state(self):
     )
 
 
-class ModelWithMultipleCollections(nn.Module):
+@nnx.register_variable_name("special_variables")
+class SpecialVariables(nnx.Variable):
+  pass
+
+
+class ModelWithMultipleCollections(nnx.Module):
   """
   A simple model that has variables in multiple collections - "params" and "special_variables"
   """
 
-  dense: Dense = nn.Dense(4)
-
-  def setup(self):
-    self.kernel = self.variable("special_variables", "my_first_kernel", lambda: jnp.ones((4, 5)))
+  def __init__(self, input_dim: int, rngs: nnx.Rngs | None = None):
+    self.dense = nnx.Linear(input_dim, 4, rngs=rngs)
+    self.my_first_kernel = SpecialVariables(jnp.ones((4, 5)))
 
   def __call__(self, x, y, encoder_images=None, nnx_method=None, model_mode=None):
     x = self.dense(x)
-    x = x @ self.kernel.value
+    x = x @ self.my_first_kernel
     return x
 
 
+class TrainState(train_state.TrainState):
+  other_variables: nnx.State
+
+
 class MaxUtilsInitStateWithMultipleCollections(unittest.TestCase):
   """test class for multiple collection state in maxutils"""
 
   def setUp(self):
     self.config = pyconfig.initialize(
         [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], enable_checkpointing=False
     )
-    self.model = ModelWithMultipleCollections()
-    self.key1, self.key2, self.key3 = random.split(random.key(0), num=3)
-    self.input = random.normal(self.key1, (self.config.global_batch_size_to_load, self.config.max_target_length))
-    self.params = self.model.init(self.key2, self.input, self.input)
+    self.model = ModelWithMultipleCollections(self.config.max_target_length, nnx.Rngs(0))
+    self.key = random.key(0)
     self.tx = optax.adam(learning_rate=0.001)
 
   def _test_init_initial_state_driver(self, is_training):
     """test initiating of the initial state driver"""
-    state_under_test = maxtext_utils.init_initial_state(self.model, self.tx, self.config, is_training, self.key3)
-    self.assertEqual(state_under_test.apply_fn, self.model.apply)
+    if is_training:
+      self.model.train()
+    else:
+      self.model.eval()
+
+    graphdef, params, other_variables = nnx.split(self.model, nnx.Param, ...)
+
+    state_under_test = None
+    if is_training:
+      state_under_test = TrainState.create(
+          apply_fn=graphdef.apply, params=params, other_variables=other_variables, tx=self.tx
+      )
+    else:
+      state_under_test = TrainState(
+          step=0, apply_fn=graphdef.apply, params=params, other_variables=other_variables, tx=None, opt_state={}
+      )
+
+    self.assertEqual(state_under_test.apply_fn, graphdef.apply)
     if is_training:
       self.assertEqual(state_under_test.tx, self.tx)
       self.assertNotEqual(state_under_test.opt_state, {})
@@ -197,11 +219,11 @@ def _test_init_initial_state_driver(self, is_training):
       self.assertEqual(state_under_test.opt_state, {})
     self.assertEqual(
         max_utils.calculate_num_params_from_pytree(state_under_test.params),
-        max_utils.calculate_num_params_from_pytree(self.params),
+        max_utils.calculate_num_params_from_pytree(params),
     )
-    self.assertEqual(len(self.params), len(state_under_test.params))
-    self.assertIn("special_variables", state_under_test.params)
-    self.assertIn("params", state_under_test.params)
+    self.assertEqual(len(params), len(state_under_test.params))
+    self.assertIsInstance(state_under_test.other_variables["my_first_kernel"], SpecialVariables)
+    self.assertTrue(hasattr(state_under_test, "params"))
 
   def test_initial_train_state(self):
     self._test_init_initial_state_driver(True)
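
For reference, here is a minimal self-contained sketch of the NNX pattern this file migrates to: a custom nnx.Variable subclass stands in for the old Linen "special_variables" collection, nnx.split partitions the module into params and everything else, and a TrainState subclass carries the remainder. This sketch is not part of the commit; TinyModel and its sizes are illustrative, and it assumes a recent Flax release with the nnx API.

# Sketch only: TinyModel is illustrative, not from the commit.
import jax.numpy as jnp
import optax
from flax import nnx
from flax.training import train_state

@nnx.register_variable_name("special_variables")  # reuse the old collection name
class SpecialVariables(nnx.Variable):
  pass

class TinyModel(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.dense = nnx.Linear(3, 4, rngs=rngs)       # captured by the nnx.Param filter
    self.scale = SpecialVariables(jnp.ones((4,)))  # captured by the trailing ... filter

  def __call__(self, x):
    return self.dense(x) * self.scale.value

class TrainState(train_state.TrainState):
  other_variables: nnx.State  # everything nnx.Param does not capture

model = TinyModel(nnx.Rngs(0))
# Split into a static graphdef plus two disjoint states.
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)
state = TrainState.create(
    apply_fn=graphdef.apply,
    params=params,
    other_variables=other_variables,
    tx=optax.adam(1e-3),
)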

tests/multi_token_prediction_test.py

Lines changed: 37 additions & 46 deletions
@@ -19,7 +19,7 @@
 import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
-from flax import linen as nn
+from flax import nnx
 
 from MaxText.common_types import Config
 from MaxText import max_logging, pyconfig
@@ -29,7 +29,7 @@
 from MaxText.layers import multi_token_prediction  # The class under test
 from MaxText.layers import embeddings
 from MaxText.common_types import MODEL_MODE_TRAIN
-
+from MaxText.layers import nnx_wrappers
 
 TEST_LAYER_NUM = 1
 
@@ -122,29 +122,35 @@ def test_multi_token_prediction_layer_output(self):
 
 
 # A lightweight wrapper model for robustly testing the MTPBlock.
-class MTPBlockTestModel(nn.Module):
+class MTPBlockTestModel(nnx.Module):
   """A lightweight wrapper model for testing the MTPBlock."""
 
-  config: Config
-  mesh: Mesh
-
-  def setup(self):
+  def __init__(
+      self,
+      config: Config,
+      mesh: Mesh,
+      rngs: nnx.Rngs | None = None,
+  ):
+    self.config = config
+    self.mesh = mesh
     """Initializes the MTP block and its dependencies for the test."""
-    self.shared_embedding = embeddings.embed_as_linen(
-        mesh=self.mesh,
+    self.shared_embedding = embeddings.Embed(
         num_embeddings=self.config.vocab_size,
         num_features=self.config.base_emb_dim,
         config=self.config,
-        name="shared_embedding",
+        mesh=self.mesh,
+        rngs=rngs,
     )
-    self.decoder = Decoder(config=self.config, mesh=self.mesh, name="decoder_for_mtp")
-    self.mtp_block = multi_token_prediction.MultiTokenPredictionBlock(
+    decoder_for_mtp = Decoder(config=self.config, mesh=self.mesh, name="decoder_for_mtp")
+
+    self.multi_token_prediction_block = multi_token_prediction.MultiTokenPredictionBlock(
         config=self.config,
         mesh=self.mesh,
         name="mtp_block",
         transformer_layer_module=DecoderLayer,
-        decoder=self.decoder,
+        decoder=decoder_for_mtp,
     )
+    self.mtp_block = nnx_wrappers.ToNNX(self.multi_token_prediction_block, rngs=nnx.Rngs(params=0))
 
   def __call__(
       self,
@@ -156,6 +162,7 @@ def __call__(
       decoder_segment_ids,
       model_mode,
       deterministic,
+      mutable=None,
   ):
     return self.mtp_block(
         self.shared_embedding,
@@ -167,6 +174,7 @@ def __call__(
         decoder_segment_ids,
         model_mode,
         deterministic,
+        mutable=mutable,
     )
 
 
@@ -181,6 +189,7 @@ def setUp(self):
         skip_jax_distributed_system=True,
         mtp_num_layers=2,
     )
+    self.nnx_rngs = nnx.Rngs(params=0)
     self.rng = jax.random.PRNGKey(43)
     devices_array = maxtext_utils.create_device_mesh(self.cfg)
     self.mesh = Mesh(devices_array, self.cfg.mesh_axes)
@@ -195,23 +204,11 @@ def setUp(self):
     self.position_ids = jnp.arange(self.seq_len, dtype=jnp.int32).reshape(1, -1)
     self.decoder_segment_ids = jnp.ones((self.batch_size, self.seq_len), dtype=jnp.int32)
 
-    self.test_model = MTPBlockTestModel(config=self.cfg, mesh=self.mesh)
-    self.variables = self.test_model.init(
-        {"params": self.init_rng, "dropout": self.init_rng},
-        self.main_hidden_state,
-        self.input_ids,
-        self.target_ids,
-        self.target_mask,
-        self.position_ids,
-        self.decoder_segment_ids,
-        model_mode=MODEL_MODE_TRAIN,
-        deterministic=True,
-    )
+    self.test_model = MTPBlockTestModel(config=self.cfg, mesh=self.mesh, rngs=self.nnx_rngs)
 
   def test_sow_functionality(self):
     """Verifies that the block correctly sows losses and weights."""
-    _, captured_vars = self.test_model.apply(
-        self.variables,
+    self.test_model(
         self.main_hidden_state,
         self.input_ids,
         self.target_ids,
@@ -222,25 +219,24 @@ def test_sow_functionality(self):
         model_mode=MODEL_MODE_TRAIN,
         mutable=["mtp_losses"],
     )
-    self.assertIn("mtp_losses", captured_vars)
-    sown_data = maxtext_utils.get_nested_value(captured_vars, ("mtp_losses", "mtp_block"), {})
-    self.assertIn("losses", sown_data)
-    self.assertEqual(len(sown_data["losses"]), self.cfg.mtp_num_layers)
+    self.assertTrue(hasattr(self.test_model.mtp_block, "losses"))
+    mtp_loss = self.test_model.mtp_block.losses
+    self.assertEqual(type(mtp_loss).__name__, "mtp_losses")
+    self.assertEqual(len(mtp_loss), self.cfg.mtp_num_layers)
 
   def test_no_sow_during_init(self):
     """Verifies no losses are sown during model initialization."""
     # `self.variables` was created by `.init()`. We inspect it to ensure
     # our `if not self.is_initializing()` check worked.
-    self.assertNotIn("mtp_losses", self.variables)
+    self.assertFalse(hasattr(self.test_model.mtp_block, "losses"))
 
   def test_loss_aggregation_logic(self):
     """
     Tests the full 'sow and reap' cycle, mimicking the logic from train.py
     to ensure the final loss calculation is correct.
     """
     # 1. Run the forward pass and capture the sown variables.
-    _, captured_vars = self.test_model.apply(
-        self.variables,
+    self.test_model(
         self.main_hidden_state,
         self.input_ids,
         self.target_ids,
@@ -250,26 +246,21 @@ def test_loss_aggregation_logic(self):
         deterministic=False,
         mutable=["mtp_losses"],
         model_mode=MODEL_MODE_TRAIN,
-        rngs={"dropout": self.rng},
     )
 
     # This section of the test now *becomes* the logic from train.py
     # -------------------------------------------------------------
     final_loss_for_gradient = 100.0  # A dummy main loss
     mtp_loss_for_logging = 0.0
 
-    # 2. Define the exact path to retrieve the sown variables.
-    losses_path = ("mtp_losses", "mtp_block", "losses")
-    weights_path = ("mtp_losses", "mtp_block", "weights")
-
-    # 3. Use the standard utility to get the data.
-    mtp_losses = maxtext_utils.get_nested_value(captured_vars, losses_path, default=())
-    mtp_weights = maxtext_utils.get_nested_value(captured_vars, weights_path, default=())
+    # 2. Get the weights and losses.
+    mtp_losses = self.test_model.mtp_block.losses.value
+    mtp_weights = self.test_model.mtp_block.weights.value
 
-    # 4. Perform the aggregation logic exactly as in `loss_fn`.
+    # 3. Perform the aggregation logic exactly as in `loss_fn`.
     if mtp_losses:
-      sum_of_all_mtp_losses = jnp.sum(jnp.array(mtp_losses))
-      sum_of_all_mtp_weights = jnp.sum(jnp.array(mtp_weights))
+      sum_of_all_mtp_losses = jnp.sum(jnp.array(mtp_losses)).item()
+      sum_of_all_mtp_weights = jnp.sum(jnp.array(mtp_weights)).item()
 
       self.assertGreater(sum_of_all_mtp_weights, 0)
 
@@ -280,7 +271,7 @@ def test_loss_aggregation_logic(self):
       mtp_loss_for_logging = scaled_mtp_loss
     # -------------------------------------------------------------
 
-    # 5. Assert that the final values are correct.
+    # 4. Assert that the final values are correct.
     # The final loss should have increased from its base value.
     self.assertGreater(final_loss_for_gradient, 100.0)
     # The logged MTP loss should be a valid, positive number.
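
The mtp_block above wraps a Linen module for use from NNX code via MaxText's nnx_wrappers.ToNNX. For orientation, Flax ships the same bridging idea as flax.nnx.bridge; the sketch below is not part of the commit, uses that public API with an illustrative Linen module, and assumes a recent Flax release.

# Sketch only: LinenBlock is illustrative, not from the commit.
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx

class LinenBlock(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(4)(x)

# Wrap the Linen module so it can live inside an NNX module tree; its
# variables are created lazily from the first concrete inputs.
block = nnx.bridge.ToNNX(LinenBlock(), rngs=nnx.Rngs(params=0))
nnx.bridge.lazy_init(block, jnp.ones((1, 3)))
y = block(jnp.ones((1, 3)))  # variables now live as regular NNX state on `block`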

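The sow assertions in these tests read sown values back as attributes on the block rather than out of a returned variable dict. Below is a minimal pure-NNX sketch of that round trip; it is not part of the commit, SowingBlock is illustrative, and it assumes a recent Flax release.

# Sketch only: SowingBlock is illustrative, not from the commit.
import jax.numpy as jnp
from flax import nnx

class SowingBlock(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.dense = nnx.Linear(2, 2, rngs=rngs)

  def __call__(self, x):
    y = self.dense(x)
    # Appends to a tuple held in an nnx.Intermediate variable named "losses".
    self.sow(nnx.Intermediate, "losses", jnp.mean(y**2))
    return y

block = SowingBlock(nnx.Rngs(0))
assert not hasattr(block, "losses")  # nothing sown before the first call
block(jnp.ones((1, 2)))
block(jnp.ones((1, 2)))
assert len(block.losses.value) == 2  # one sown value per call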