
Commit 25b5de7

Add support for z-loss in pre-training
1 parent fcaecd2 · commit 25b5de7

6 files changed

Lines changed: 159 additions & 20 deletions


src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -334,6 +334,7 @@ sliding_window_size: 0
 chunk_attn_window_size: 0
 attn_logits_soft_cap: 0.0
 final_logits_soft_cap: 0.0
+z_loss_multiplier: 0.0
 use_post_attn_norm: False
 use_post_ffw_norm: False
 mla_naive_kvcache: True

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
@@ -462,6 +462,7 @@ class Logits(BaseModel):
       None,
       description="Soft-cap value for the final logits. None or 0.0 means no cap.",
   )
+  z_loss_multiplier: float = Field(0.0, description="The multiplier for the z-loss (e.g., 1e-4). 0.0 to disable.")


 class Attention(BaseModel):
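
The new multiplier controls the z-loss regularizer: on top of the cross-entropy term, each token pays a penalty proportional to the squared log of its softmax partition function, which keeps the logits from drifting to large magnitudes. The real computation lives in max_utils.cross_entropy_with_logits; the sketch below only illustrates the quantity the multiplier scales, using the same formula the new unit test checks (standalone JAX, illustrative names, not the MaxText implementation).

import jax
import jax.numpy as jnp

def z_loss_term(logits: jnp.ndarray, z_loss_multiplier: float) -> jnp.ndarray:
  """Per-token z-loss penalty: multiplier * log(Z)^2, with Z the softmax normalizer."""
  # Reduce over the vocabulary axis; one penalty value remains per token position.
  log_z = jax.scipy.special.logsumexp(logits, axis=-1)
  return z_loss_multiplier * jnp.square(log_z)

# With z_loss_multiplier=0.0 the penalty vanishes, so the default leaves training unchanged.
penalty = z_loss_term(jnp.array([[1.0, 2.0, 3.0]]), 1e-4)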

src/maxtext/trainers/pre_train/train.py

Lines changed: 29 additions & 3 deletions
@@ -136,20 +136,32 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
     if config.num_vocab_tiling > 1:
       hidden_state_key = ("intermediates", "decoder", "hidden_states")
       hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0]
-      total_loss = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train)
+      total_loss, total_z_loss = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train)
     else:
       one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size)
-      xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets)
+      xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier)
+
       xent = sharding.maybe_shard_with_logical(
           xent,
           ("activation_embed_and_logits_batch", "activation_length"),
           model.mesh,
           config.shard_mode,
           debug_sharding=config.debug_sharding,
       )
+      z_loss = sharding.maybe_shard_with_logical(
+          z_loss,
+          ("activation_embed_and_logits_batch", "activation_length"),
+          model.mesh,
+          config.shard_mode,
+          debug_sharding=config.debug_sharding,
+      )
+
       # Mask out paddings at the end of each example.
       xent = xent * (data["targets_segmentation"] != 0)
+      z_loss = z_loss * (data["targets_segmentation"] != 0)
+
       total_loss = jnp.sum(xent)
+      total_z_loss = jnp.sum(z_loss)
   else:
     # Flax NNX model
     logits = model(
@@ -164,11 +176,17 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
     )
     intermediate_outputs = {}
     one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size)
-    xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets)
+    xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier)
+
     xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length"))
+    z_loss = nn.with_logical_constraint(z_loss, ("activation_embed_and_logits_batch", "activation_length"))
+
     # Mask out paddings at the end of each example.
     xent = xent * (data["targets_segmentation"] != 0)
+    z_loss = z_loss * (data["targets_segmentation"] != 0)
+
     total_loss = jnp.sum(xent)
+    total_z_loss = jnp.sum(z_loss)

   total_weights = jnp.sum(data["targets_segmentation"] != 0)
   # If gradient accumulation is enabled, we don't need to divide total_loss
@@ -188,6 +206,9 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
   # updates and scaling internally.
   loss = total_loss / (total_weights + EPS)

+  # We keep z-loss normalized by total_weights.
+  total_z_loss = total_z_loss / (total_weights + EPS)
+
   # Calculate and Add MTP Loss
   mtp_loss = 0.0
   if config.mtp_num_layers > 0 and is_train:
@@ -230,6 +251,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
   aux = {
       "intermediate_outputs": intermediate_outputs,
       "total_loss": total_loss,
+      "z_loss": total_z_loss,
       "total_weights": total_weights,
       "moe_lb_loss": moe_lb_loss,
       "moe_bias_updates": moe_bias_updates,
@@ -302,6 +324,7 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
   intermediate_outputs = aux["intermediate_outputs"]
   total_weights = aux["total_weights"]
   moe_lb_loss = aux["moe_lb_loss"]
+  z_loss = aux["z_loss"]
   moe_bias_updates = aux["moe_bias_updates"]
   mtp_loss = aux["mtp_loss"]

@@ -345,6 +368,7 @@ def move(path, value):

   scalar_metrics = {
       "learning/loss": loss,
+      "learning/z_loss": z_loss,
       "learning/moe_lb_loss": moe_lb_loss,
       "learning/mtp_loss": mtp_loss,
       "learning/total_weights": total_weights,
@@ -395,12 +419,14 @@ def eval_step(model, config, state, data, dropout_rng):
   mtp_acceptance_rate = calculate_mtp_acceptance_rate(aux["intermediate_outputs"], config)

   total_loss = aux["total_loss"]
+  z_loss = aux["z_loss"]
   total_weights = aux["total_weights"]
   moe_lb_loss = aux["moe_lb_loss"]
   mtp_loss = aux["mtp_loss"]
   metrics = {
       "scalar": {
           "evaluation/loss": loss,
+          "evaluation/z_loss": z_loss,
           "evaluation/total_loss": total_loss,
           "evaluation/total_weights": total_weights,
           "evaluation/moe_lb_loss": moe_lb_loss,

src/maxtext/utils/vocabulary_tiling.py

Lines changed: 25 additions & 14 deletions
@@ -52,7 +52,7 @@ def vocab_tiling_linen_loss(
     params: The model parameters.
     is_train: A boolean indicating if the model is in training mode.
   Returns:
-    The total cross-entropy loss computed via vocab tiling.
+    A tuple of (total_loss, total_z_loss) computed via vocab tiling.
   """
   labels = data["targets"]
   segmentation = data["targets_segmentation"]
@@ -112,8 +112,8 @@ def chunked_cross_entropy_loss(gathered_params, hidden_states, labels, segmentat
     """
     Calculates the total cross-entropy loss using vocab tiling.
     """
-    total_loss, _ = _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation)
-    return total_loss
+    (total_loss, total_z_loss), _ = _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation)
+    return total_loss, total_z_loss

   def _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation):
     batch_size, seq_len, emb_dim = hidden_states.shape
@@ -126,7 +126,8 @@ def _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segm
     reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)

     # Scan body accumulates loss from each tile given chunked hidden states and labels
-    def _fwd_scan_body(loss_accumulator, chunk_data):
+    def _fwd_scan_body(accumulators, chunk_data):
+      loss_accumulator, z_loss_accumulator = accumulators
       hidden_chunk, label_chunk, segmentation_chunk = chunk_data
       hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec)
       label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec)
@@ -141,14 +142,20 @@ def _fwd_scan_body(loss_accumulator, chunk_data):
       )
       chunk_logits = _maybe_shard_with_name(chunk_logits, chunked_logits_spec)
       one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size)
-      chunk_xent, _ = max_utils.cross_entropy_with_logits(chunk_logits, one_hot_label_chunk)
+      chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits(
+          chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier
+      )
+
       masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0))
+      masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0))
+
       loss_accumulator += masked_xent
-      return loss_accumulator, None
+      z_loss_accumulator += masked_z_loss
+      return (loss_accumulator, z_loss_accumulator), None

-    initial_loss = 0.0
-    total_loss, _ = jax.lax.scan(
-        _fwd_scan_body, initial_loss, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
+    initial_acc = (0.0, 0.0)
+    (total_loss, total_z_loss), _ = jax.lax.scan(
+        _fwd_scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
     )
     residuals = (
         gathered_params,
@@ -160,9 +167,13 @@ def _fwd_scan_body(loss_accumulator, chunk_data):
         emb_dim,
     )

-    return total_loss, residuals
+    return (total_loss, total_z_loss), residuals
+
+  def _chunked_cross_entropy_loss_bwd(residuals, cotangents):
+    # Unpack the cotangents tuple. We ignore the z_loss cotangent since the gradients
+    # of the z_loss term are already factored into the loss_cotangent.
+    loss_cotangent, _ = cotangents

-  def _chunked_cross_entropy_loss_bwd(residuals, loss_cotangent):
     gathered_params, reshaped_hidden_states, reshaped_labels, reshaped_segmentation, batch_size, seq_len, emb_dim = (
         residuals
     )
@@ -176,7 +187,7 @@ def _single_chunk_loss_fn(input_params, input_hidden_chunk, input_label_chunk, i
       )
       chunk_logits = _maybe_shard_with_name(chunk_logits, chunked_logits_spec)
       one_hot_label_chunk = jax.nn.one_hot(input_label_chunk, config.vocab_size)
-      xent, _ = max_utils.cross_entropy_with_logits(chunk_logits, one_hot_label_chunk)
+      xent, _ = max_utils.cross_entropy_with_logits(chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier)
       return jnp.sum(xent * (input_segmentation_chunk != 0))

     def _bwd_scan_body(grad_params_acc, chunk_data):
@@ -228,11 +239,11 @@ def _bwd_scan_body(grad_params_acc, chunk_data):

   chunked_cross_entropy_loss.defvjp(_chunked_cross_entropy_loss_fwd, _chunked_cross_entropy_loss_bwd)

-  total_loss = chunked_cross_entropy_loss(
+  total_loss, total_z_loss = chunked_cross_entropy_loss(
       gathered_params,
       hidden_states,
       labels,
       segmentation,
   )

-  return total_loss
+  return total_loss, total_z_loss
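
Since chunked_cross_entropy_loss now returns a (loss, z_loss) pair, its custom VJP receives a matching pair of cotangents, and the backward pass drops the z-loss cotangent because, per the comment above, that term's gradient is already folded into the cross-entropy path. A toy illustration of the tuple-cotangent pattern in isolation (standalone JAX, not the MaxText loss itself):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def loss_pair(x):
  # Primal: a differentiated loss plus a monitored auxiliary value.
  return jnp.sum(x * x), jnp.sum(x)

def loss_pair_fwd(x):
  return loss_pair(x), x  # second element is the residual for the backward pass

def loss_pair_bwd(residual_x, cotangents):
  # Cotangents arrive as a tuple matching the primal outputs; the auxiliary
  # cotangent is ignored, mirroring _chunked_cross_entropy_loss_bwd above.
  loss_ct, _ = cotangents
  return (loss_ct * 2.0 * residual_x,)

loss_pair.defvjp(loss_pair_fwd, loss_pair_bwd)

grad = jax.grad(lambda x: loss_pair(x)[0])(jnp.ones(3))  # -> [2., 2., 2.]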

tests/unit/max_utils_test.py

Lines changed: 29 additions & 0 deletions
@@ -86,6 +86,35 @@ def test_t5x_cross_entropy(self):
     # Compare results
     self.assertTrue(jax.numpy.allclose(optax_xent, t5x_xent, rtol=1e-05, atol=1e-08, equal_nan=False))

+  def test_cross_entropy_with_z_loss(self):
+    """Tests the exact mathematical output of the z-loss penalty."""
+    # Logits of shape [2, 2, 3] to test across batch and sequence dimensions
+    logits = jnp.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [1.0, 1.0, 1.0]]])
+    # Target indices: [2, 1], [0, 2]
+    targets = jnp.array([[2, 1], [0, 2]])
+    one_hot_targets = jax.nn.one_hot(targets, 3)
+
+    z_loss_multiplier = 1e-4
+
+    # 1. Run without z-loss
+    total_loss_no_z, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=0.0)
+
+    # 2. Run with z-loss
+    total_loss_with_z, z_loss_only = max_utils.cross_entropy_with_logits(
+        logits, one_hot_targets, z_loss=z_loss_multiplier
+    )
+
+    # 3. Calculate expected z-loss manually
+    # Expected log_z = log(sum(exp(logits), axis=-1))
+    expected_log_z = jax.scipy.special.logsumexp(logits, axis=-1)
+    expected_z_loss = z_loss_multiplier * jnp.square(expected_log_z)
+
+    # Compare isolated z_loss component
+    self.assertTrue(jnp.allclose(z_loss_only, expected_z_loss, rtol=1e-5, atol=1e-8))
+
+    # Compare total loss aggregation
+    self.assertTrue(jnp.allclose(total_loss_with_z, total_loss_no_z + z_loss_only, rtol=1e-5, atol=1e-8))
+

 class MaxUtilsCustomMesh(unittest.TestCase):
   """Tests for the is_valid_custom_mesh function in max_utils.py"""

tests/unit/tiling_test.py

Lines changed: 74 additions & 3 deletions
@@ -19,10 +19,13 @@
 """

 import unittest
+import pytest
+
 from flax import linen as nn
 import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
+
 from maxtext.configs import pyconfig
 from maxtext.common.common_types import Config
 from maxtext.common.common_types import MODEL_MODE_TRAIN
@@ -31,8 +34,8 @@
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
 from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss
+
 from tests.utils.test_helpers import get_test_config_path
-import pytest


 def compute_loss_linen(intermediate_outputs, logits, data, config, model, params, is_train):
@@ -42,10 +45,10 @@ def compute_loss_linen(intermediate_outputs, logits, data, config, model, params
   if config.num_vocab_tiling > 1:
     hidden_state_key = ("intermediates", "decoder", "hidden_states")
     hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0]
-    total_loss = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train)
+    total_loss, _ = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train)
   else:
     one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size)
-    xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets)
+    xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier)
     xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length"))
     # Mask out paddings at the end of each example.
     xent = xent * (data["targets_segmentation"] != 0)
@@ -186,6 +189,74 @@ def test_gradient_accumulation(self):
         "Gradients of embedding table do not match for GA.",
     )

+  @pytest.mark.tpu_only
+  def test_vocab_tiling_gradient_with_z_loss(self):
+    """
+    Tests loss and gradient correctness when z-loss is enabled, comparing
+    standard computation vs. vocabulary tiling computation.
+    """
+    cfg_non_tiling = pyconfig.initialize(
+        self.base_config,
+        run_name="grad_test_z_loss_no_tiling",
+        enable_checkpointing=False,
+        enable_dropout=False,
+        max_target_length=self.seq_len,
+        per_device_batch_size=self.batch_size,
+        logits_via_embedding=False,
+        base_num_decoder_layers=0,
+        dtype="float32",
+        matmul_precision="high",
+        num_vocab_tiling=1,
+        z_loss_multiplier=1e-4,  # Enable z-loss
+    )
+    quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling)
+    devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling)
+    mesh_non_tiling = Mesh(devices_array_non_tiling, cfg_non_tiling.mesh_axes)
+    model_non_tiling = models.transformer_as_linen(
+        cfg_non_tiling, mesh=mesh_non_tiling, quant=quant_non_tiling, model_mode=MODEL_MODE_TRAIN
+    )
+
+    rng_model, rng_targets = jax.random.split(self.rng)
+
+    params = model_non_tiling.init(
+        {"params": rng_model, "dropout": rng_model},
+        self.dummy_inputs,
+        self.dummy_inputs,
+    )
+
+    data = {
+        "targets": jax.random.randint(rng_targets, (self.batch_size, self.seq_len), 0, cfg_non_tiling.vocab_size),
+        "targets_segmentation": jnp.ones((self.batch_size, self.seq_len)),
+    }
+
+    loss_non_tiling, grads_non_tiling = self.get_grads(cfg_non_tiling, params, data)
+
+    cfg_tiling = pyconfig.initialize(
+        self.base_config,
+        run_name="grad_test_z_loss_with_tiling",
+        enable_checkpointing=False,
+        enable_dropout=False,
+        max_target_length=self.seq_len,
+        per_device_batch_size=self.batch_size,
+        logits_via_embedding=False,
+        base_num_decoder_layers=0,
+        dtype="float32",
+        matmul_precision="high",
+        num_vocab_tiling=4,
+        z_loss_multiplier=1e-4,  # Enable z-loss
+    )
+    loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data)
+
+    # Loss correctness test
+    assert jnp.allclose(loss_non_tiling, loss_tiling, rtol=self.rtol), "Losses do not match when z-loss is enabled."
+
+    # Gradient correctness test
+    self.assert_pytrees_all_close(
+        grads_non_tiling,
+        grads_tiling,
+        "Gradients do not match for vocab tiling when z-loss is enabled.",
+    )
+
   @pytest.mark.tpu_only
   def test_vocab_tiling_gradient_non_tied_embedding(self):
     """
