Skip to content

Commit 8bc46e6

Browse files
committed
Add flops calculation for DeepSeek v3.2
1 parent 96f1375 commit 8bc46e6

3 files changed

Lines changed: 164 additions & 14 deletions

File tree

src/MaxText/layers/attention_mla.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def __call__(
226226
2. K = RoPE(Norm(Wk @ X))
227227
3. Logits = ReLU(Q @ K.T) # Pairwise similarity
228228
4. Head_Weights = (W_proj @ X) * scale # Dynamic head importance, scale for stability
229-
5. Score = Sum_head(Logits * Head_Weights) # Aggregate heads
229+
5. Score = Logits @ Head_Weights # Aggregate heads
230230
6. Indices = ArgTopk(Score)
231231
232232
Args:
@@ -281,7 +281,7 @@ def __call__(
281281
# Weights scaling affect index_score, but does not affect topk_indices. Keep scaling for numerical stability.
282282
# https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/87e509a2e5a100d221c97df52c6e8be7835f0057/inference/model.py#L478-L480
283283
weights = weights * (self.n_heads**-0.5) * self.softmax_scale
284-
# Weighted sum over head: sum_h(logits * weights)
284+
# Aggregate head-wise logits: logits @ weights
285285
index_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision) # [b, t, s]
286286

287287
# Apply attention mask before TopK

src/MaxText/maxtext_utils.py

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -319,11 +319,86 @@ def calculate_llama4_attention_tflops(config):
319319
return attention_tflops
320320

321321

322+
def calculate_indexer_mask_ratio(index_topk, max_target_length):
  """Calculates the sparse-to-dense ratio of the indexer's scored attention area.

  The indexer scores all previous tokens in a causal manner until it hits the
  Top-K limit, after which every query scores exactly K tokens.

  Visual Representation for MFU (T=8, K=4):
    Key (S) ->
    Q1 [X . . . . . . .]  <- 1 token scored
    Q2 [X X . . . . . .]  <- 2 tokens scored
    Q3 [X X X . . . . .]  <- 3 tokens scored
    Q4 [X X X X . . . .]  <- 4 tokens scored (K limit reached)
    Q5 [X X X X . . . .]  <- 4 tokens scored
    Q6 [X X X X . . . .]  <- 4 tokens scored
    Q7 [X X X X . . . .]  <- 4 tokens scored
    Q8 [X X X X . . . .]  <- 4 tokens scored

  Mathematical Calculation:
    - Triangle (Phase 1, queries 1..K): ~K^2 / 2
    - Rectangle (Phase 2, queries K+1..T): (T - K) * K
    - Total Active Area = T*K - 0.5*K^2
    - Dense Area = T^2
    - Ratio = (T*K - 0.5*K^2) / T^2 = (K/T) - 0.5*(K/T)^2

  Args:
    index_topk: Top-K limit K of the sparse indexer.
    max_target_length: Sequence length T.

  Returns:
    Fraction of the dense T x T score matrix actually evaluated, in (0, 0.5].
  """
  T = float(max_target_length)
  K = float(index_topk)

  # Clamp to the dense-causal limit: once K >= T every query scores its whole
  # causal prefix, i.e. ratio == 1 recovers the dense causal area T^2 / 2.
  # Without the clamp the quadratic formula is meaningless for K > T
  # (e.g. K == 2T would yield 0).
  ratio = min(K / T, 1.0)
  mask_multiplier = ratio - (0.5 * ratio**2)
  return mask_multiplier
368+
369+
370+
def calculate_indexer_tflops_per_device(config):
371+
"""Calculates TFLOPs for the DeepSeek Lightning Indexer (handles causal reduction)."""
372+
batch_len = config.per_device_batch_size * config.max_target_length
373+
374+
# 1. Calculate projections flops
375+
# Query: [batch, seq, q_lora_rank] @ [q_lora_rank, index_n_heads, index_head_dim]
376+
q_flops = 2 * batch_len * config.q_lora_rank * config.index_n_heads * config.index_head_dim
377+
# Key: [batch, seq, emb_dim] @ [emb_dim, index_head_dim]
378+
k_flops = 2 * batch_len * config.emb_dim * config.index_head_dim
379+
# Head weight: [batch, seq, emb_dim] @ [emb_dim, index_n_heads]
380+
head_weight_flops = 2 * batch_len * config.emb_dim * config.index_n_heads
381+
proj_flops = q_flops + k_flops + head_weight_flops
382+
383+
# 2. Calculate index score flops
384+
# QK product [batch, seq, index_n_heads, index_head_dim] @ [batch, seq, index_head_dim]
385+
# --> [batch, seq, seq, index_n_heads]
386+
qk_product_flops = 2 * batch_len * config.max_target_length * config.index_n_heads * config.index_head_dim
387+
# Aggregate heads [batch, seq, seq, index_n_heads] @ [batch, seq, index_n_heads]
388+
head_reduction_flops = 2 * batch_len * config.max_target_length * config.index_n_heads
389+
# Apply causal mask: Divide by 2 to account for triangular interactions
390+
# The mask restricts the indexer's search space prior to Top-K filtering
391+
scoring_flops = (qk_product_flops + head_reduction_flops) / 2
392+
393+
return proj_flops, scoring_flops
394+
395+
322396
def calculate_mla_tflops_per_device(config):
323-
"""Calculate Multi-Head Latent Attention TFLOP"""
397+
"""Calculate Multi-Head Latent Attention TFLOP (handles causal reduction)"""
324398
batch_len = config.per_device_batch_size * config.max_target_length
325399
qk_head_dim_sum = config.qk_nope_head_dim + config.qk_rope_head_dim
326-
# calculate mla query projection
400+
401+
# 1. calculate mla query projection
327402
if config.q_lora_rank == 0:
328403
q_flops = 2 * batch_len * config.emb_dim * config.num_query_heads * qk_head_dim_sum
329404
else:
@@ -333,7 +408,8 @@ def calculate_mla_tflops_per_device(config):
333408
* batch_len
334409
* (config.emb_dim * config.q_lora_rank + config.q_lora_rank * config.num_query_heads * qk_head_dim_sum)
335410
)
336-
# calculate mla kv projection with down and up flops
411+
412+
# 2. calculate mla kv projection
337413
kv_flops = (
338414
2
339415
* batch_len
@@ -344,9 +420,31 @@ def calculate_mla_tflops_per_device(config):
344420
)
345421
qkv_flops = q_flops + kv_flops
346422

347-
attention_flops = (
348-
2 * batch_len * config.max_target_length * config.num_query_heads * (qk_head_dim_sum + config.v_head_dim)
349-
)
423+
# 3. calculate attention
424+
if config.use_sparse_indexer and config.max_target_length > config.index_topk:
425+
# get indexer flops
426+
indexer_proj_flops, indexer_scoring_flops = calculate_indexer_tflops_per_device(config)
427+
qkv_flops += indexer_proj_flops
428+
429+
# calculate the proportion of the T x T causal matrix that the Indexer actually explores
430+
# this follows the area: (TK - 0.5*K^2) / T^2 (T: max_target_length, K: index_topk)
431+
multiplier = calculate_indexer_mask_ratio(config.index_topk, config.max_target_length)
432+
attention_flops = (
433+
2
434+
* batch_len
435+
* config.max_target_length
436+
* config.num_query_heads
437+
* (qk_head_dim_sum + config.v_head_dim)
438+
* multiplier
439+
)
440+
attention_flops += indexer_scoring_flops
441+
else:
442+
# standard MLA & max_target_length <= index_topk in sparse indexer
443+
# in both cases, the indexer is bypassed as the causal mask remains the efficient representation
444+
attention_flops = (
445+
2 * batch_len * config.max_target_length * config.num_query_heads * (qk_head_dim_sum + config.v_head_dim)
446+
)
447+
attention_flops = attention_flops / 2
350448
projection_flops = 2 * batch_len * config.emb_dim * config.num_query_heads * config.v_head_dim
351449
return qkv_flops, attention_flops, projection_flops
352450

@@ -546,7 +644,7 @@ def calculate_tflops_training_per_device(config, log=True):
546644

547645
# Attention flops
548646
if config.attention_type == "mla":
549-
qkv_flops, noncausal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
647+
qkv_flops, causal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
550648
else:
551649
qkv_flops = (
552650
2
@@ -568,11 +666,11 @@ def calculate_tflops_training_per_device(config, log=True):
568666
* config.head_dim
569667
)
570668

571-
# Divide attention flops by 2 due to causal mask
572-
# References:
573-
# NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
574-
# NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
575-
causal_attention_flops = noncausal_attention_flops / 2
669+
# Divide attention flops by 2 due to causal mask
670+
# References:
671+
# NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
672+
# NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
673+
causal_attention_flops = noncausal_attention_flops / 2
576674

577675
# Embedding flops
578676
embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size

tests/unit/flop_calculation_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,55 @@ def test_gpt_oss_20b_flops(self):
292292
)
293293
calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
294294
self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
295+
296+
@pytest.mark.cpu_only
def test_deepseek32_671b_flops(self):
  """Checks the DeepSeek3.2-671b FLOPs estimate against a parameter-count golden."""
  kwargs = {
      # Model bases
      "model_name": "deepseek3.2-671b",
      "override_model_config": True,
      # Core workload parameters
      "per_device_batch_size": 4,
      "max_target_length": 4096,
      "num_experts": 256,
      "num_experts_per_tok": 8,
      "shared_experts": 1,
      # Model dimensions
      "base_emb_dim": 7168,
      "base_num_query_heads": 128,
      "base_num_kv_heads": 128,
      "base_mlp_dim": 18432,
      "base_moe_mlp_dim": 2048,
      "base_num_decoder_layers": 61,
      "first_num_dense_layers": 3,
      "mlp_activations": ["silu", "linear"],
      "vocab_size": 129280,
      # MLA
      "q_lora_rank": 1536,
      "kv_lora_rank": 512,
      "qk_nope_head_dim": 128,
      "qk_rope_head_dim": 64,
      "v_head_dim": 128,
      "skip_jax_distributed_system": True,
      # Indexer for DeepSeek Sparse Attention
      "use_sparse_indexer": True,
      "index_n_heads": 64,
      "index_head_dim": 128,
      "index_topk": 2048,
      # TODO(ranran): remove after flash attention is supported
      "attention": "dot_product",
  }
  attention_flops = self.compute_deepseek_attention_flops_per_device(kwargs)
  batch_size = kwargs["per_device_batch_size"]
  seq_len = kwargs["max_target_length"]
  # deepseek3-671b has ~37B active parameters (https://arxiv.org/pdf/2412.19437),
  # so the dense compute follows the 6 * tokens * params rule of thumb.
  golden_param_size = 37e9
  golden_tflops = 6 * batch_size * seq_len * golden_param_size / 1e12 + attention_flops
  cfg = pyconfig.initialize(
      [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
      **kwargs,
  )
  calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
  self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)

0 commit comments

Comments
 (0)