@@ -110,9 +110,9 @@ def __init__(
110110 self .dtype = config .dtype
111111 self .weight_dtype = config .weight_dtype
112112
113- self .n_heads = config .index_n_heads
114- self .head_dim = config .index_head_dim
115- self .index_topk = config .index_topk
113+ self .n_heads = config .indexer_n_heads
114+ self .head_dim = config .indexer_head_dim
115+ self .indexer_topk = config .indexer_topk
116116 self .emb_dim = config .emb_dim
117117 self .rope_head_dim = config .qk_rope_head_dim
118118 self .q_lora_rank = config .q_lora_rank
@@ -180,13 +180,13 @@ def apply_partial_rope(
180180 2. Input Layout: Indexer uses concatenated layout (interleave=False), whereas MLA uses interleaved (interleave=True).
181181
182182 Args:
183- inputs: Input array of shape [batch, seqlen, index_n_heads, index_head_dim ].
183+ inputs: Input array of shape [batch, seqlen, indexer_n_heads, indexer_head_dim ].
184184 positions: Position array of shape [batch, seqlen].
185185
186186 Returns:
187- Array with partial RoPE applied, with shape [batch, seqlen, index_n_heads, index_head_dim ]
187+ Array with partial RoPE applied, with shape [batch, seqlen, indexer_n_heads, indexer_head_dim ]
188188 """
189- # index_head_dim -> [rope_head_dim, index_head_dim - rope_head_dim]
189+ # indexer_head_dim -> [rope_head_dim, indexer_head_dim - rope_head_dim]
190190 x_pe , x_nope = jnp .split (inputs , [self .rope_head_dim ], axis = - 1 )
191191 # x_pe [B, S, H, rope_head_dim], positions [B, S]
192192 x_pe = self .rotary_embedding (x_pe , position = inputs_positions )
@@ -256,14 +256,37 @@ def __call__(
256256 b: Batch size
257257 t: Query Sequence Length (Target), note t = s here
258258 s: Key/Value Sequence Length (Source)
259- h: Number of Indexer Heads (index_n_heads )
260- d: Indexer Head Dimension (index_head_dim )
259+ h: Number of Indexer Heads (indexer_n_heads )
260+ d: Indexer Head Dimension (indexer_head_dim )
261261 """
262262 # NOTE: If sequence length <= topk, indexer always selects all tokens.
263- if self .config .max_target_length <= self .index_topk :
263+ if self .config .max_target_length <= self .indexer_topk :
264264 return None , None , None
265265
266266 bsz , seqlen , _ = inputs_q .shape # s = t = seqlen
267+ # ==============================================================================
268+ # Gradient Isolation Strategy: Main Model vs. Indexer
269+ # ==============================================================================
270+ # This creates a barrier to train both components independently, and applies
271+ # to both the Dense Warm-up and Sparse Training stages:
272+ #
273+ # Forward Pass:
274+ # - The Indexer receives a detached copy of the inputs (via `stop_gradient`)
275+ # to independently calculate its scores and `indexer_loss`.
276+ #
277+ # Backward Pass (Main Model):
278+ # - The main model optimizes its weights based solely on the LM loss.
279+ # - The `indexer_mask` in the Attention layer prevents gradients from the main
280+ # loss from flowing into the Indexer's weights.
281+ #
282+ # Backward Pass (Indexer):
283+ # - Gradients from the `indexer_loss` flow back to update the Indexer's weights.
284+ # - The `stop_gradient` applied to the inputs acts as a mathematical wall, dropping
285+ # gradients to 0.0 and preventing the Indexer loss from altering the main model's
286+ # earlier layers.
287+ inputs_q = jax .lax .stop_gradient (inputs_q )
288+ low_rank_q = jax .lax .stop_gradient (low_rank_q )
289+ inputs_kv = jax .lax .stop_gradient (inputs_kv )
267290
268291 # Query Processing: Project from Latent low_rank_q
269292 q = self .wq_b (low_rank_q ) # [b, t, q_lora_rank] -> [b, t, h * d]
@@ -295,7 +318,7 @@ def __call__(
295318 indexer_score += attention_mask
296319
297320 # TopK selection based on index score
298- _ , topk_indices = jax .lax .top_k (indexer_score , k = self .index_topk ) # topk_indices [b, t, k]
321+ _ , topk_indices = jax .lax .top_k (indexer_score , k = self .indexer_topk ) # topk_indices [b, t, k]
299322
300323 # Create Sparse Index Mask: 0 and large negatives
301324 indexer_mask = self .generate_mask (topk_indices , seqlen ) # [b, t, s]
@@ -607,8 +630,8 @@ def __init__(
607630 )
608631
609632 # Initialize Indexer
610- self .use_sparse_indexer = config .use_sparse_indexer
611- if self .use_sparse_indexer :
633+ self .use_indexer = config .use_indexer
634+ if self .use_indexer :
612635 # Need two versions of rope.
613636 # MLA applies yarn with interleave layout.
614637 # Indexer applies yarn with concatenate layout.
@@ -989,6 +1012,13 @@ def calculate_indexer_loss(
9891012 Returns:
9901013 The computed KL divergence loss.
9911014 """
1015+ # Detach main model components from the computational graph.
1016+ # The indexer should match the main model, but the main model should not be influenced
1017+ # by the indexer's learning progress via this loss in the sparse training stage.
1018+ # We also apply this during the Dense Warm-up stage to save compute and memory.
1019+ query = jax .lax .stop_gradient (query )
1020+ key = jax .lax .stop_gradient (key )
1021+
9921022 # Compute attention scores: [b, t, h, d] @ [b, s, h, d] -> [b, h, t, s]
9931023 attention_scores = jnp .einsum ("bthd, bshd -> bhts" , query , key , precision = self .config .matmul_precision )
9941024
@@ -1080,7 +1110,7 @@ def __call__(
10801110
10811111 # Indexer Logic
10821112 indexer_mask = None
1083- if self .use_sparse_indexer :
1113+ if self .use_indexer :
10841114 if model_mode != MODEL_MODE_TRAIN :
10851115 raise NotImplementedError ("Sparse indexer has not implemented for inference yet." )
10861116 # generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len]
@@ -1098,14 +1128,14 @@ def __call__(
10981128 attention_mask = attention_mask ,
10991129 )
11001130
1101- if self .config .indexer_loss_scaling_factor > 0.0 :
1131+ if indexer_mask is not None and self .config .indexer_loss_scaling_factor > 0.0 :
11021132 indexer_loss = self .calculate_indexer_loss (
11031133 indexer_score = indexer_score ,
11041134 query = query ,
11051135 key = key ,
11061136 attention_mask = attention_mask ,
11071137 indexer_mask = indexer_mask ,
1108- sparse_loss = self .config .sparse_indexer_loss ,
1138+ sparse_loss = self .config .indexer_sparse_training ,
11091139 scaling_factor = self .config .indexer_loss_scaling_factor ,
11101140 )
11111141 self .sow (nnx .Intermediate , "indexer_loss" , indexer_loss )
0 commit comments