@@ -336,6 +336,7 @@ def mla(
336336 qk_nope_head_dim = qk_nope_head_dim ,
337337 mscale = mscale ,
338338 )
339+ query = jax .ad_checkpoint .checkpoint_name (query , "query_proj" )
339340 key , value = kv_projection (
340341 inputs ,
341342 positions ,
@@ -355,6 +356,8 @@ def mla(
355356 qk_nope_head_dim = qk_nope_head_dim ,
356357 num_query_heads = num_query_heads ,
357358 )
359+ key = jax .ad_checkpoint .checkpoint_name (key , "key_proj" )
360+ value = jax .ad_checkpoint .checkpoint_name (value , "value_proj" )
358361 out = attention_op_fn (
359362 query ,
360363 key ,
@@ -363,7 +366,9 @@ def mla(
363366 model_mode ,
364367 cached_values = [None , None ],
365368 )
369+ out = jax .ad_checkpoint .checkpoint_name (out , "attention_out" )
366370 out = dot (out , out_weights , axes = 2 )
371+ out = jax .ad_checkpoint .checkpoint_name (out , "out_proj" )
367372 return out
368373
369374
@@ -402,6 +407,7 @@ def query_projection(
402407 epsilon = epsilon ,
403408 dtype = dtype ,
404409 )
410+ low_rank_q = jax .ad_checkpoint .checkpoint_name (low_rank_q , "mla_q" )
405411 q = dot (low_rank_q , wq_b_weights )
406412
407413 # Split into non-positional and rotary parts.
@@ -451,6 +457,7 @@ def kv_projection(
451457 epsilon = kv_norm_epsilon ,
452458 dtype = dtype ,
453459 )
460+ low_rank_main = jax .ad_checkpoint .checkpoint_name (low_rank_main , "mla_kv" )
454461 key_rope = jnp .expand_dims (low_rank_rope , axis = 2 )
455462 key_rope = yarn (
456463 key_rope ,
@@ -690,6 +697,8 @@ def compute(x, w0, w1, wo, group_sizes, weights, *, wi_tile_size, wo_tile_size,
690697 )
691698 layer_w0 = gmm_fn (x , w0 , tiling = wi_tile_size )
692699 layer_w1 = gmm_fn (x , w1 , tiling = wi_tile_size )
700+ layer_w0 = jax .ad_checkpoint .checkpoint_name (layer_w0 , "mlpwi_0" )
701+ layer_w1 = jax .ad_checkpoint .checkpoint_name (layer_w1 , "mlpwi_1" )
693702 intermediate_layer = jax .nn .silu (layer_w0 ) * layer_w1
694703 intermediate_layer *= weights [:, None ]
695704 return gmm_fn (intermediate_layer , wo , tiling = wo_tile_size )
0 commit comments