Skip to content

Commit cc0d3ae

Browse files
Merge pull request #3415 from AI-Hypercomputer:indexer_train_strategy
PiperOrigin-RevId: 887003047
2 parents 41a4e9d + 0b55a28 commit cc0d3ae

15 files changed

Lines changed: 400 additions & 107 deletions

src/maxtext/configs/base.yml

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -356,15 +356,15 @@ moba_topk: 8
356356

357357
# DeepSeek Sparse Attention (DSA)
358358
# deepseek3.2 introduces indexer in MLA
359-
use_sparse_indexer: False
360-
index_head_dim: 128
361-
index_n_heads: 64
362-
index_topk: 2048
363-
# Determines the token selection strategy for indexer loss:
364-
# - False: Uses all tokens (Dense Warm-up).
365-
# - True: Uses only top-k tokens (Sparse Training).
359+
use_indexer: False
360+
indexer_head_dim: 128
361+
indexer_n_heads: 64
362+
indexer_topk: 2048
363+
# Determines the training strategy for the indexer:
364+
# - False (Dense Warm-up): Computes indexer loss over all tokens. Used with `trainable_parameters_mask` to freeze other model parameters.
365+
# - True (Sparse Training): Computes indexer loss over top-k tokens only and detaches the indexer input for independent optimization.
366366
# Note: The indexer loss is only computed when `indexer_loss_scaling_factor` > 0.
367-
sparse_indexer_loss: False
367+
indexer_sparse_training: False
368368
# Multiplier for the indexer KL divergence loss
369369
indexer_loss_scaling_factor: 0.0
370370

@@ -790,6 +790,10 @@ gradient_clipping_threshold: 1.0
790790
gradient_accumulation_steps: 1
791791

792792
opt_type: "adamw" # one of "adamw", "adam_pax", "sgd", or "muon"
793+
# List of parameter names/patterns to train.
794+
# If non-empty, all other parameters will be frozen. Example: ['.*indexer.*'].
795+
# If empty (default), all parameters are trained.
796+
trainable_parameters_mask: []
793797

794798
# AdamW optimizer parameters
795799
# We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2

src/maxtext/configs/models/deepseek-custom.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ rope_interleave: True
5656
rope_truncate: True
5757
rope_attention_scaling: False
5858
# Indexer for DeepSeek Sparse Attention
59-
use_sparse_indexer: True
60-
index_n_heads: 16 # Reduced from 64
61-
index_head_dim: 64 # Reduced from 128
62-
index_topk: 256 # Reduced from 2048
59+
use_indexer: True
60+
indexer_n_heads: 16 # Reduced from 64
61+
indexer_head_dim: 64 # Reduced from 128
62+
indexer_topk: 256 # Reduced from 2048
6363
# Hyper-connections: mHC enabled
6464
mhc_expansion_rate: 4
6565
sinkhorn_iterations: 20

src/maxtext/configs/models/deepseek3.2-671b.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ rope_interleave: True
5353
rope_truncate: True
5454
rope_attention_scaling: False
5555
# Indexer for DeepSeek Sparse Attention
56-
use_sparse_indexer: True
57-
index_n_heads: 64
58-
index_head_dim: 128
59-
index_topk: 2048
56+
use_indexer: True
57+
indexer_n_heads: 64
58+
indexer_head_dim: 128
59+
indexer_topk: 2048

src/maxtext/configs/types.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -542,11 +542,14 @@ class MlaAttention(BaseModel):
542542
class AttentionIndexer(BaseModel):
543543
"""Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer."""
544544

545-
use_sparse_indexer: bool = Field(False, description="Whether to use sparse indexer for MLA.")
546-
index_head_dim: NonNegativeInt = Field(128, description="Head dim for indexer query and key.")
547-
index_n_heads: NonNegativeInt = Field(64, description="Number of query heads in indexer.")
548-
index_topk: NonNegativeInt = Field(2048, description="Number of tokens selected by the query token in indexer.")
549-
sparse_indexer_loss: bool = Field(False, description="Determines the token selection strategy for indexer loss.")
545+
use_indexer: bool = Field(False, description="Whether to use sparse indexer for MLA.")
546+
indexer_head_dim: NonNegativeInt = Field(128, description="Head dim for indexer query and key.")
547+
indexer_n_heads: NonNegativeInt = Field(64, description="Number of query heads in indexer.")
548+
indexer_topk: NonNegativeInt = Field(2048, description="Number of tokens selected by the query token in indexer.")
549+
indexer_sparse_training: bool = Field(
550+
False,
551+
description="Determines the training strategy for the indexer: Dense Warm-up or Sparse Training stage.",
552+
)
550553
indexer_loss_scaling_factor: float = Field(0.0, description="Multiplier for the indexer KL divergence loss.")
551554

552555

@@ -1182,6 +1185,13 @@ class Optimizer(BaseModel):
11821185
ge=-1,
11831186
description="Total steps for the LR schedule. -1 defaults to `steps`.",
11841187
)
1188+
trainable_parameters_mask: list[str] = Field(
1189+
default_factory=list,
1190+
description=(
1191+
"List of parameter names/patterns to train. If non-empty, all other parameters will be frozen, "
1192+
"example: ['.*indexer.*']. If empty (default), all parameters are trained."
1193+
),
1194+
)
11851195

11861196

11871197
class AdamW(BaseModel):
@@ -2385,7 +2395,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
23852395
raise ValueError("`local_checkpoint_period` must be > 0 for emergency checkpointing.")
23862396
if self.moba and self.attention not in ("dot_product"):
23872397
raise ValueError("MoBA is only supported with dot_product attention.")
2388-
if self.use_sparse_indexer:
2398+
if self.use_indexer:
23892399
if self.q_lora_rank == 0:
23902400
raise NotImplementedError("Sparse indexer has not implemented for q_lora_rank = 0.")
23912401
supports_dot_product = self.attention == "dot_product"

src/maxtext/layers/attention_mla.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,9 @@ def __init__(
110110
self.dtype = config.dtype
111111
self.weight_dtype = config.weight_dtype
112112

113-
self.n_heads = config.index_n_heads
114-
self.head_dim = config.index_head_dim
115-
self.index_topk = config.index_topk
113+
self.n_heads = config.indexer_n_heads
114+
self.head_dim = config.indexer_head_dim
115+
self.indexer_topk = config.indexer_topk
116116
self.emb_dim = config.emb_dim
117117
self.rope_head_dim = config.qk_rope_head_dim
118118
self.q_lora_rank = config.q_lora_rank
@@ -180,13 +180,13 @@ def apply_partial_rope(
180180
2. Input Layout: Indexer uses concatenated layout (interleave=False), whereas MLA uses interleaved (interleave=True).
181181
182182
Args:
183-
inputs: Input array of shape [batch, seqlen, index_n_heads, index_head_dim].
183+
inputs: Input array of shape [batch, seqlen, indexer_n_heads, indexer_head_dim].
184184
positions: Position array of shape [batch, seqlen].
185185
186186
Returns:
187-
Array with partial RoPE applied, with shape [batch, seqlen, index_n_heads, index_head_dim]
187+
Array with partial RoPE applied, with shape [batch, seqlen, indexer_n_heads, indexer_head_dim]
188188
"""
189-
# index_head_dim -> [rope_head_dim, index_head_dim - rope_head_dim]
189+
# indexer_head_dim -> [rope_head_dim, indexer_head_dim - rope_head_dim]
190190
x_pe, x_nope = jnp.split(inputs, [self.rope_head_dim], axis=-1)
191191
# x_pe [B, S, H, rope_head_dim], positions [B, S]
192192
x_pe = self.rotary_embedding(x_pe, position=inputs_positions)
@@ -256,14 +256,37 @@ def __call__(
256256
b: Batch size
257257
t: Query Sequence Length (Target), note t = s here
258258
s: Key/Value Sequence Length (Source)
259-
h: Number of Indexer Heads (index_n_heads)
260-
d: Indexer Head Dimension (index_head_dim)
259+
h: Number of Indexer Heads (indexer_n_heads)
260+
d: Indexer Head Dimension (indexer_head_dim)
261261
"""
262262
# NOTE: If sequence length <= topk, indexer always selects all tokens.
263-
if self.config.max_target_length <= self.index_topk:
263+
if self.config.max_target_length <= self.indexer_topk:
264264
return None, None, None
265265

266266
bsz, seqlen, _ = inputs_q.shape # s = t = seqlen
267+
# ==============================================================================
268+
# Gradient Isolation Strategy: Main Model vs. Indexer
269+
# ==============================================================================
270+
# This creates a barrier so that both components train independently; it applies
270+
# to both the Dense Warm-up and Sparse Training stages:
272+
#
273+
# Forward Pass:
274+
# - The Indexer receives a detached copy of the inputs (via `stop_gradient`)
275+
# to independently calculate its scores and `indexer_loss`.
276+
#
277+
# Backward Pass (Main Model):
278+
# - The main model optimizes its weights based solely on the LM loss.
279+
# - The `indexer_mask` in the Attention layer prevents gradients from the main
280+
# loss from flowing into the Indexer's weights.
281+
#
282+
# Backward Pass (Indexer):
283+
# - Gradients from the `indexer_loss` flow back to update the Indexer's weights.
284+
# - The `stop_gradient` applied to the inputs acts as a mathematical wall, dropping
285+
# gradients to 0.0 and preventing the Indexer loss from altering the main model's
286+
# earlier layers.
287+
inputs_q = jax.lax.stop_gradient(inputs_q)
288+
low_rank_q = jax.lax.stop_gradient(low_rank_q)
289+
inputs_kv = jax.lax.stop_gradient(inputs_kv)
267290

268291
# Query Processing: Project from Latent low_rank_q
269292
q = self.wq_b(low_rank_q) # [b, t, q_lora_rank] -> [b, t, h * d]
@@ -295,7 +318,7 @@ def __call__(
295318
indexer_score += attention_mask
296319

297320
# TopK selection based on index score
298-
_, topk_indices = jax.lax.top_k(indexer_score, k=self.index_topk) # topk_indices [b, t, k]
321+
_, topk_indices = jax.lax.top_k(indexer_score, k=self.indexer_topk) # topk_indices [b, t, k]
299322

300323
# Create Sparse Index Mask: 0 and large negatives
301324
indexer_mask = self.generate_mask(topk_indices, seqlen) # [b, t, s]
@@ -607,8 +630,8 @@ def __init__(
607630
)
608631

609632
# Initialize Indexer
610-
self.use_sparse_indexer = config.use_sparse_indexer
611-
if self.use_sparse_indexer:
633+
self.use_indexer = config.use_indexer
634+
if self.use_indexer:
612635
# Need two versions of rope.
613636
# MLA applies yarn with interleave layout.
614637
# Indexer applies yarn with concatenate layout.
@@ -989,6 +1012,13 @@ def calculate_indexer_loss(
9891012
Returns:
9901013
The computed KL divergence loss.
9911014
"""
1015+
# Detach main model components from the computational graph.
1016+
# The indexer should match the main model, but the main model should not be influenced
1017+
# by the indexer's learning progress via this loss in the sparse training stage.
1018+
# We also apply this during the Dense Warm-up stage to save compute and memory.
1019+
query = jax.lax.stop_gradient(query)
1020+
key = jax.lax.stop_gradient(key)
1021+
9921022
# Compute attention scores: [b, t, h, d] @ [b, s, h, d] -> [b, h, t, s]
9931023
attention_scores = jnp.einsum("bthd, bshd -> bhts", query, key, precision=self.config.matmul_precision)
9941024

@@ -1080,7 +1110,7 @@ def __call__(
10801110

10811111
# Indexer Logic
10821112
indexer_mask = None
1083-
if self.use_sparse_indexer:
1113+
if self.use_indexer:
10841114
if model_mode != MODEL_MODE_TRAIN:
10851115
raise NotImplementedError("Sparse indexer has not implemented for inference yet.")
10861116
# generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len]
@@ -1098,14 +1128,14 @@ def __call__(
10981128
attention_mask=attention_mask,
10991129
)
11001130

1101-
if self.config.indexer_loss_scaling_factor > 0.0:
1131+
if indexer_mask is not None and self.config.indexer_loss_scaling_factor > 0.0:
11021132
indexer_loss = self.calculate_indexer_loss(
11031133
indexer_score=indexer_score,
11041134
query=query,
11051135
key=key,
11061136
attention_mask=attention_mask,
11071137
indexer_mask=indexer_mask,
1108-
sparse_loss=self.config.sparse_indexer_loss,
1138+
sparse_loss=self.config.indexer_sparse_training,
11091139
scaling_factor=self.config.indexer_loss_scaling_factor,
11101140
)
11111141
self.sow(nnx.Intermediate, "indexer_loss", indexer_loss)

src/maxtext/layers/attention_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1415,7 +1415,7 @@ def wrap_flash_attention(
14151415
decoder_segment_ids_tuple = None
14161416

14171417
if self.config.use_tokamax_splash:
1418-
if self.config.use_sparse_indexer and indexer_mask is not None:
1418+
if self.config.use_indexer and indexer_mask is not None:
14191419
# Construct the splash kernel call with dynamic mask
14201420
def dynamic_mask_splash_kernel(q, k, v, segment, sinks, indexer_mask):
14211421
splash_kernel = tokamax_splash_kernel.make_dynamic_splash_mha(

src/maxtext/layers/decoders.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,10 @@ def __call__(
10811081
# When invoking from vLLM with RPA attention, logit computation is deferred to a later stage.
10821082
if cfg.attention == "vllm_rpa":
10831083
logits = None
1084-
1084+
# When in the Indexer Dense Warm-up stage, skip the expensive output head projection
1085+
# for efficiency, as the main model is frozen and the LM loss is not needed.
1086+
elif (cfg.use_indexer and not cfg.indexer_sparse_training) and self.model_mode == MODEL_MODE_TRAIN:
1087+
logits = None
10851088
# When vocab tiling is enabled in training mode, full logits are not generated, to reduce memory.
10861089
# Instead, we keep track of the hidden states, which are smaller than the full logits.
10871090
elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:

src/maxtext/optimizers/optimizers.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,31 +24,35 @@
2424
from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers
2525

2626

27-
def get_adamw_mask(config):
28-
"""Create a mask function for AdamW optimizer to exclude certain parameters from weight decay."""
29-
if not getattr(config, "adamw_mask", None):
27+
def _get_path_mask_fn(patterns, match_returns_true=True):
28+
"""Helper to create a mask function from a list of regex patterns."""
29+
if not patterns:
3030
return None
3131

32-
compiled_patterns = [re.compile(pattern) for pattern in config.adamw_mask]
32+
compiled_patterns = [re.compile(pattern) for pattern in patterns]
3333

3434
def mask_fn(params):
35-
def _is_decayed(path, _):
35+
def _is_masked(path, _):
3636
# Join path keys into a single string for pattern matching (e.g., "layer1/bias")
37-
path_str = "/".join(str(getattr(p, "key", getattr(p, "idx", getattr(p, "name", p)))) for p in path)
38-
# If any pattern in adamw_mask matches the path, exclude from weight decay (return False).
39-
# Otherwise, apply weight decay (return True).
40-
return not any(pattern.search(path_str) for pattern in compiled_patterns)
37+
path_str = jax.tree_util.keystr(path, simple=True, separator="/")
38+
matched = any(pattern.search(path_str) for pattern in compiled_patterns)
39+
return matched if match_returns_true else not matched
4140

42-
return jax.tree_util.tree_map_with_path(_is_decayed, params)
41+
return jax.tree_util.tree_map_with_path(_is_masked, params)
4342

4443
return mask_fn
4544

4645

46+
def get_adamw_mask(config):
47+
"""Create a mask function for AdamW optimizer to exclude certain parameters from weight decay."""
48+
return _get_path_mask_fn(getattr(config, "adamw_mask", None), match_returns_true=False)
49+
50+
4751
def get_optimizer(config, learning_rate_schedule, model=None):
4852
"""Create optimizer."""
4953
if config.opt_type == "adamw":
5054
# Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
51-
return optax.adamw(
55+
base_opt = optax.adamw(
5256
learning_rate_schedule,
5357
b1=config.adam_b1,
5458
b2=config.adam_b2,
@@ -59,7 +63,7 @@ def get_optimizer(config, learning_rate_schedule, model=None):
5963
mask=get_adamw_mask(config),
6064
)
6165
elif config.opt_type == "adam_pax":
62-
return adam_pax(
66+
base_opt = adam_pax(
6367
learning_rate_schedule,
6468
beta1=config.adam_b1,
6569
beta2=config.adam_b2,
@@ -69,7 +73,7 @@ def get_optimizer(config, learning_rate_schedule, model=None):
6973
mask=get_adamw_mask(config),
7074
)
7175
elif config.opt_type == "sgd":
72-
return optax.sgd(learning_rate_schedule)
76+
base_opt = optax.sgd(learning_rate_schedule)
7377
elif config.opt_type == "muon":
7478
# extract muon dimension number from model structure
7579
if model is not None:
@@ -92,10 +96,26 @@ def get_optimizer(config, learning_rate_schedule, model=None):
9296
"adam_eps_root": config.adam_eps_root,
9397
"adam_weight_decay": config.adam_weight_decay,
9498
}
95-
return muon(**muon_kwargs)
99+
base_opt = muon(**muon_kwargs)
96100
else:
97101
raise ValueError(f"{config.opt_type=} is not a supported.")
98102

103+
# If a whitelist of trainable parameters is provided, freeze everything else.
104+
# When trainable_parameters_mask is empty, freeze_mask_fn is None and all parameters are trained.
105+
trainable_patterns = getattr(config, "trainable_parameters_mask", None)
106+
freeze_mask_fn = _get_path_mask_fn(trainable_patterns, match_returns_true=False)
107+
if freeze_mask_fn is not None:
108+
# Use optax.multi_transform to explicitly map frozen parameters to a stateless set_to_zero() optimizer.
109+
# If we simply wrapped base_opt in optax.masked() or chained it, Optax would still allocate
110+
# massive states (momentum, variance) for the entire model before zeroing the updates.
111+
# By using multi_transform, only the trainable parameters get states allocated.
112+
return optax.multi_transform(
113+
{"trainable": base_opt, "frozen": optax.set_to_zero()},
114+
lambda params: jax.tree_util.tree_map(lambda x: "frozen" if x else "trainable", freeze_mask_fn(params)),
115+
)
116+
117+
return base_opt
118+
99119

100120
def adam_pax(
101121
learning_rate_fn: optax.Schedule,

0 commit comments

Comments
 (0)