Skip to content

Commit 95d5969

Browse files
committed
DeepSeek3.2: Onboard sparse attention
1 parent c636924 commit 95d5969

11 files changed

Lines changed: 1417 additions & 22 deletions

pytest.ini

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ addopts =
1414
--ignore=tests/unit/gemma3_layers_test.py
1515
--ignore=tests/unit/gpt_vs_reference_test.py
1616
--ignore=tests/unit/llama4_layers_test.py
17-
--ignore=tests/unit/mla_vs_reference_test.py
17+
--ignore=tests/unit/yarn_vs_reference_test.py
1818
--ignore=tests/unit/moba_vs_reference_test.py
1919
--ignore=tests/unit/offline_engine_test.py
2020
--ignore=tests/unit/profiler_test.py
2121
--ignore=tests/unit/qwen3_omni_layers_test.py
2222
--ignore=tests/unit/qwen3_next_vs_reference_test.py
23+
--ignore=tests/unit/deepseek32_vs_reference_test.py
2324
markers =
2425
tpu_only: marks tests to be run on TPUs only
2526
gpu_only: marks tests to be run on GPUs only

src/MaxText/configs/base.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,13 @@ moba: False
328328
moba_chunk_size: 1024
329329
moba_topk: 8
330330

331+
# DeepSeek Sparse Attention (DSA)
332+
# DeepSeek 3.2 introduces an indexer in MLA
333+
use_sparse_indexer: False
334+
index_head_dim: 128
335+
index_n_heads: 64
336+
index_topk: 2048
337+
331338
# MLA parameters
332339
q_lora_rank: 0
333340
kv_lora_rank: 512
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# model config for DeepSeek V3.2 - 671B
16+
# Identical to the deepseek3-671b config, except for the added indexer config.
17+
18+
base_emb_dim: 7168
19+
base_num_query_heads: 128
20+
base_num_kv_heads: 128
21+
base_mlp_dim: 18432
22+
base_moe_mlp_dim: 2048
23+
base_num_decoder_layers: 61
24+
first_num_dense_layers: 3
25+
mlp_activations: ["silu","linear"]
26+
vocab_size: 129280
27+
enable_dropout: False
28+
logits_via_embedding: False
29+
normalization_layer_epsilon: 1.0e-6
30+
num_experts: 256
31+
num_experts_per_tok: 8
32+
shared_experts: 1
33+
routed_scaling_factor: 2.5
34+
routed_score_func: "sigmoid"
35+
routed_bias: True
36+
decoder_block: "deepseek"
37+
# MLA
38+
attention_type: "mla"
39+
q_lora_rank: 1536
40+
kv_lora_rank: 512
41+
qk_nope_head_dim: 128
42+
qk_rope_head_dim: 64
43+
v_head_dim: 128
44+
# RoPE
45+
mscale: 1.0
46+
rope_type: "yarn"
47+
rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
48+
max_position_embeddings: 163840
49+
original_max_position_embeddings: 4096
50+
rope_factor: 40
51+
beta_fast: 32
52+
rope_interleave: True
53+
rope_truncate: True
54+
rope_attention_scaling: False
55+
# Indexer for DeepSeek Sparse Attention
56+
use_sparse_indexer: True
57+
index_n_heads: 64
58+
index_head_dim: 128
59+
index_topk: 2048

src/MaxText/configs/types.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ class ProfilerType(str, Enum):
207207
"deepseek3-671b-2dfsdp",
208208
"deepseek3-test",
209209
"deepseek3-tiny",
210+
"deepseek3.2-671b",
210211
"kimi-k2-1t",
211212
"gemma-7b",
212213
"gemma-2b",
@@ -502,6 +503,15 @@ class MlaAttention(BaseModel):
502503
v_head_dim: NonNegativeInt = Field(128, description="Dimension of V heads in MLA.")
503504

504505

506+
class AttentionIndexer(BaseModel):
507+
"""Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer."""
508+
509+
use_sparse_indexer: bool = Field(False, description="Whether to use sparse indexer for MLA.")
510+
index_head_dim: NonNegativeInt = Field(128, description="Head dim for indexer query and key.")
511+
index_n_heads: NonNegativeInt = Field(64, description="Number of query heads in indexer.")
512+
index_topk: NonNegativeInt = Field(2048, description="Number of tokens selected per query token in the indexer.")
513+
514+
505515
class Llama4Attention(BaseModel):
506516
"""Configuration specific to Llama4-style models."""
507517

@@ -1686,6 +1696,7 @@ class MaxTextConfig(
16861696
Attention,
16871697
MlaAttention,
16881698
MoBa,
1699+
AttentionIndexer,
16891700
Llama4Attention,
16901701
SplashAttention,
16911702
PagedAttention,
@@ -2120,6 +2131,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
21202131
raise ValueError("`local_checkpoint_period` must be > 0 for emergency checkpointing.")
21212132
if self.moba and self.attention not in ("dot_product",):
21222133
raise ValueError("MoBA is only supported with dot_product attention.")
2134+
if self.use_sparse_indexer:
2135+
if self.q_lora_rank == 0:
2136+
raise NotImplementedError("Sparse indexer has not been implemented for q_lora_rank = 0.")
2137+
if self.attention not in ("dot_product",):
2138+
raise ValueError("Sparse indexer is only supported with dot_product attention.")
21232139
if self.attention_type == AttentionType.CHUNK.value and (
21242140
not isinstance(self.chunk_attn_window_size, int) or self.chunk_attn_window_size <= 0
21252141
):
@@ -2259,9 +2275,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
22592275
f"`python3 -m MaxText.muon_utils {self.model_name} True`"
22602276
)
22612277
if self.force_q_layout and not self.use_jax_splash:
2262-
raise ValueError(
2263-
"`force_q_layout` can only be true if `use_jax_splash` is also true."
2264-
)
2278+
raise ValueError("`force_q_layout` can only be true if `use_jax_splash` is also true.")
22652279

22662280
# I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
22672281
# Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.

0 commit comments

Comments
 (0)