Skip to content

Commit f358282

Browse files
Merge pull request #3130 from AI-Hypercomputer:bvandermoon-mtp-params
PiperOrigin-RevId: 871486617
2 parents a3c3fd1 + 18ed9ee commit f358282

2 files changed

Lines changed: 40 additions & 27 deletions

File tree

src/MaxText/layers/multi_token_prediction.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,40 @@ def __init__(
128128
model_mode=MODEL_MODE_TRAIN,
129129
)
130130

131+
132+
@property
133+
def embedding_norm(self):
134+
return getattr(self, f"mtp_{self.layer_number}_embedding_norm")
135+
136+
@embedding_norm.setter
137+
def embedding_norm(self, module):
138+
setattr(self, f"mtp_{self.layer_number}_embedding_norm", module)
139+
140+
@property
141+
def hidden_state_norm(self):
142+
return getattr(self, f"mtp_{self.layer_number}_hidden_state_norm")
143+
144+
@hidden_state_norm.setter
145+
def hidden_state_norm(self, module):
146+
setattr(self, f"mtp_{self.layer_number}_hidden_state_norm", module)
147+
148+
@property
149+
def projection_layer(self):
150+
return getattr(self, f"mtp_{self.layer_number}_projection")
151+
152+
@projection_layer.setter
153+
def projection_layer(self, module):
154+
setattr(self, f"mtp_{self.layer_number}_projection", module)
155+
156+
@property
157+
def transformer_layer(self):
158+
return getattr(self, f"mtp_{self.layer_number}_transformer_layer")
159+
160+
@transformer_layer.setter
161+
def transformer_layer(self, module):
162+
setattr(self, f"mtp_{self.layer_number}_transformer_layer", module)
163+
164+
131165
def __call__(
132166
self,
133167
prev_hidden_state: jnp.ndarray,
@@ -192,13 +226,6 @@ def __init__(
192226
self.decoder = decoder
193227
self.rngs = rngs if rngs is not None else nnx.Rngs(0)
194228

195-
# NNX Variables are exposed as Linen mutable collections by ToLinen wrapper.
196-
self.losses = mtp_losses(jnp.zeros((config.mtp_num_layers,), dtype=jnp.float32))
197-
self.weights = mtp_losses(jnp.zeros((config.mtp_num_layers,), dtype=jnp.float32))
198-
# Float32 used to avoid gradient errors; converted to int32 in acceptance rate calculation.
199-
self.mtp_preds = mtp_acceptance(jnp.zeros((1,), dtype=jnp.float32))
200-
self.mtp_mask = mtp_acceptance(jnp.zeros((1,), dtype=jnp.float32))
201-
202229
# 1-indexed to match paper convention.
203230
for k in range(1, config.mtp_num_layers + 1):
204231
layer = MultiTokenPredictionLayer(
@@ -278,11 +305,13 @@ def __call__(
278305
mtp_masks_list.append(rolled_target_mask)
279306

280307
if mtp_losses_list:
281-
self.losses.value = jnp.stack(mtp_losses_list)
282-
self.weights.value = jnp.stack(mtp_weights_list)
308+
# Not part of checkpoints, don't declare in __init__
309+
self.losses = mtp_losses(jnp.stack(mtp_losses_list))
310+
self.weights = mtp_losses(jnp.stack(mtp_weights_list))
283311
if mtp_preds_list:
284-
self.mtp_preds.value = jnp.stack(mtp_preds_list)
285-
self.mtp_mask.value = jnp.stack(mtp_masks_list)
312+
# Not part of checkpoints, don't declare in __init__
313+
self.mtp_preds = mtp_acceptance(jnp.stack(mtp_preds_list))
314+
self.mtp_mask = mtp_acceptance(jnp.stack(mtp_masks_list))
286315

287316
return {}
288317

tests/unit/multi_token_prediction_test.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -231,22 +231,6 @@ def setUp(self):
231231
rngs=self.rngs,
232232
)
233233

234-
def test_no_sow_during_init(self):
235-
"""Verifies losses/weights are initialized with zeros (NNX behavior)."""
236-
# NNX pre-initializes Variables with zeros to avoid checkpointing errors.
237-
# Unlike Linen which sows during forward pass, NNX creates Variables in __init__.
238-
initial_state = nnx.state(self.test_model)
239-
self.assertTrue(hasattr(initial_state.mtp_block, "losses"))
240-
self.assertTrue(hasattr(initial_state.mtp_block, "weights"))
241-
242-
# Verify they're initialized with zeros of correct shape.
243-
losses_val = initial_state.mtp_block.losses.value
244-
weights_val = initial_state.mtp_block.weights.value
245-
self.assertEqual(losses_val.shape, (self.cfg.mtp_num_layers,))
246-
self.assertEqual(weights_val.shape, (self.cfg.mtp_num_layers,))
247-
self.assertTrue(jnp.all(losses_val == 0.0))
248-
self.assertTrue(jnp.all(weights_val == 0.0))
249-
250234
def test_sow_functionality(self):
251235
"""Verifies that the block correctly sows losses and weights."""
252236
_ = self.test_model(

0 commit comments

Comments (0)