@@ -207,6 +207,7 @@ class ProfilerType(str, Enum):
207207 "deepseek3-671b-2dfsdp" ,
208208 "deepseek3-test" ,
209209 "deepseek3-tiny" ,
210+ "deepseek3.2-671b" ,
210211 "kimi-k2-1t" ,
211212 "gemma-7b" ,
212213 "gemma-2b" ,
@@ -502,6 +503,15 @@ class MlaAttention(BaseModel):
502503 v_head_dim : NonNegativeInt = Field (128 , description = "Dimension of V heads in MLA." )
503504
504505
506+ class AttentionIndexer (BaseModel ):
507+ """Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer."""
508+
509+ use_sparse_indexer : bool = Field (False , description = "Whether to use sparse indexer for MLA." )
510+ index_head_dim : NonNegativeInt = Field (128 , description = "Head dim for indexer query and key." )
511+ index_n_heads : NonNegativeInt = Field (64 , description = "Number of query heads in indexer." )
512+ index_topk : NonNegativeInt = Field (2048 , description = "Number of tokens selected per query token by the indexer." )
513+
514+
505515class Llama4Attention (BaseModel ):
506516 """Configuration specific to Llama4-style models."""
507517
@@ -1686,6 +1696,7 @@ class MaxTextConfig(
16861696 Attention ,
16871697 MlaAttention ,
16881698 MoBa ,
1699+ AttentionIndexer ,
16891700 Llama4Attention ,
16901701 SplashAttention ,
16911702 PagedAttention ,
@@ -2120,6 +2131,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
21202131 raise ValueError ("`local_checkpoint_period` must be > 0 for emergency checkpointing." )
21212132 if self .moba and self .attention not in ("dot_product" ):
21222133 raise ValueError ("MoBA is only supported with dot_product attention." )
2134+ if self .use_sparse_indexer :
2135+ if self .q_lora_rank == 0 :
2136+ raise NotImplementedError ("Sparse indexer has not been implemented for q_lora_rank = 0." )
2137+ if self .attention not in ("dot_product" ,):
2138+ raise ValueError ("Sparse indexer is only supported with dot_product attention." )
21232139 if self .attention_type == AttentionType .CHUNK .value and (
21242140 not isinstance (self .chunk_attn_window_size , int ) or self .chunk_attn_window_size <= 0
21252141 ):
@@ -2259,9 +2275,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
22592275 f"`python3 -m MaxText.muon_utils { self .model_name } True`"
22602276 )
22612277 if self .force_q_layout and not self .use_jax_splash :
2262- raise ValueError (
2263- "`force_q_layout` can only be true if `use_jax_splash` is also true."
2264- )
2278+ raise ValueError ("`force_q_layout` can only be true if `use_jax_splash` is also true." )
22652279
22662280 # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
22672281 # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
0 commit comments