@@ -136,20 +136,32 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
     if config.num_vocab_tiling > 1:
       hidden_state_key = ("intermediates", "decoder", "hidden_states")
       hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0]
-      total_loss = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train)
+      total_loss, total_z_loss = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train)
     else:
       one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size)
-      xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets)
+      xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier)
+
       xent = sharding.maybe_shard_with_logical(
           xent,
           ("activation_embed_and_logits_batch", "activation_length"),
           model.mesh,
           config.shard_mode,
           debug_sharding=config.debug_sharding,
       )
+      z_loss = sharding.maybe_shard_with_logical(
+          z_loss,
+          ("activation_embed_and_logits_batch", "activation_length"),
+          model.mesh,
+          config.shard_mode,
+          debug_sharding=config.debug_sharding,
+      )
+
       # Mask out paddings at the end of each example.
       xent = xent * (data["targets_segmentation"] != 0)
+      z_loss = z_loss * (data["targets_segmentation"] != 0)
+
       total_loss = jnp.sum(xent)
+      total_z_loss = jnp.sum(z_loss)
   else:
     # Flax NNX model
     logits = model(
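For reference, the z-loss regularizer (as used in PaLM-style training) penalizes the squared log of the softmax normalizer so logits do not drift to large magnitudes. Below is a minimal sketch of what a `cross_entropy_with_logits` helper with a `z_loss` multiplier conventionally computes; the function name is illustrative, and the real `max_utils.cross_entropy_with_logits` may differ in details such as using a custom VJP for memory efficiency.

```python
import jax
import jax.numpy as jnp


def cross_entropy_with_z_loss(logits, one_hot_targets, z_loss_multiplier=0.0):
  """Per-token cross entropy plus the z-loss term z_loss_multiplier * log(Z)^2."""
  log_z = jax.nn.logsumexp(logits, axis=-1)               # log of the softmax normalizer Z
  cross_entropy = log_z - jnp.sum(one_hot_targets * logits, axis=-1)
  z_loss = z_loss_multiplier * jnp.square(log_z)          # discourages large-magnitude logits
  return cross_entropy + z_loss, z_loss                   # (per-token loss incl. z-loss, z-loss alone)
```

Under this convention the returned `xent` already includes the z-loss contribution, so `total_loss` absorbs it; the separate `total_z_loss` accumulated in the diff exists purely for monitoring.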
@@ -164,11 +176,17 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
     )
     intermediate_outputs = {}
     one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size)
-    xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets)
+    xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier)
+
     xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length"))
+    z_loss = nn.with_logical_constraint(z_loss, ("activation_embed_and_logits_batch", "activation_length"))
+
     # Mask out paddings at the end of each example.
     xent = xent * (data["targets_segmentation"] != 0)
+    z_loss = z_loss * (data["targets_segmentation"] != 0)
+
     total_loss = jnp.sum(xent)
+    total_z_loss = jnp.sum(z_loss)
 
   total_weights = jnp.sum(data["targets_segmentation"] != 0)
   # If gradient accumulation is enabled, we don't need to divide total_loss
@@ -188,6 +206,9 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
   # updates and scaling internally.
   loss = total_loss / (total_weights + EPS)
 
+  # We keep z-loss normalized by total_weights.
+  total_z_loss = total_z_loss / (total_weights + EPS)
+
   # Calculate and Add MTP Loss
   mtp_loss = 0.0
   if config.mtp_num_layers > 0 and is_train:
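A small worked example (made-up values; `EPS` here is an illustrative stand-in for the module-level constant) showing why this normalization yields a mean over real tokens only: padded positions contribute zero to both the numerator and `total_weights`.

```python
import jax.numpy as jnp

EPS = 1e-8  # illustrative; the actual EPS is defined at module level
z_loss = jnp.array([[0.2, 0.3, 0.5],
                    [0.4, 0.0, 0.0]])                    # per-token z-loss, shape [batch, seq]
segmentation = jnp.array([[1, 1, 0],
                          [1, 0, 0]])                    # targets_segmentation; 0 marks padding
z_loss = z_loss * (segmentation != 0)                    # mask out padded positions
total_weights = jnp.sum(segmentation != 0)               # 3 non-padding tokens
total_z_loss = jnp.sum(z_loss) / (total_weights + EPS)   # (0.2 + 0.3 + 0.4) / 3 ~= 0.3
```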
@@ -230,6 +251,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
   aux = {
       "intermediate_outputs": intermediate_outputs,
       "total_loss": total_loss,
+      "z_loss": total_z_loss,
       "total_weights": total_weights,
       "moe_lb_loss": moe_lb_loss,
       "moe_bias_updates": moe_bias_updates,
@@ -302,6 +324,7 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
   intermediate_outputs = aux["intermediate_outputs"]
   total_weights = aux["total_weights"]
   moe_lb_loss = aux["moe_lb_loss"]
+  z_loss = aux["z_loss"]
   moe_bias_updates = aux["moe_bias_updates"]
   mtp_loss = aux["mtp_loss"]
 
@@ -345,6 +368,7 @@ def move(path, value):
 
   scalar_metrics = {
       "learning/loss": loss,
+      "learning/z_loss": z_loss,
       "learning/moe_lb_loss": moe_lb_loss,
       "learning/mtp_loss": mtp_loss,
       "learning/total_weights": total_weights,
@@ -395,12 +419,14 @@ def eval_step(model, config, state, data, dropout_rng):
     mtp_acceptance_rate = calculate_mtp_acceptance_rate(aux["intermediate_outputs"], config)
 
   total_loss = aux["total_loss"]
+  z_loss = aux["z_loss"]
   total_weights = aux["total_weights"]
   moe_lb_loss = aux["moe_lb_loss"]
   mtp_loss = aux["mtp_loss"]
   metrics = {
       "scalar": {
           "evaluation/loss": loss,
+          "evaluation/z_loss": z_loss,
           "evaluation/total_loss": total_loss,
           "evaluation/total_weights": total_weights,
           "evaluation/moe_lb_loss": moe_lb_loss,