
Commit 15177f2

Merge pull request #2979 from AI-Hypercomputer:flops_clean

PiperOrigin-RevId: 862324339
2 parents 5f7c76d + 8bc46e6

3 files changed: 164 additions & 14 deletions

src/MaxText/layers/attention_mla.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -226,7 +226,7 @@ def __call__(
     2. K = RoPE(Norm(Wk @ X))
     3. Logits = ReLU(Q @ K.T)  # Pairwise similarity
     4. Head_Weights = (W_proj @ X) * scale  # Dynamic head importance, scale for stability
-    5. Score = Sum_head(Logits * Head_Weights)  # Aggregate heads
+    5. Score = Logits @ Head_Weights  # Aggregate heads
     6. Indices = ArgTopk(Score)
 
     Args:
@@ -281,7 +281,7 @@ def __call__(
     # Weights scaling affects index_score, but does not affect topk_indices. Keep scaling for numerical stability.
     # https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/87e509a2e5a100d221c97df52c6e8be7835f0057/inference/model.py#L478-L480
     weights = weights * (self.n_heads**-0.5) * self.softmax_scale
-    # Weighted sum over head: sum_h(logits * weights)
+    # Aggregate head-wise logits: logits @ weights
     index_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision)  # [b, t, s]
 
     # Apply attention mask before TopK
```
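The wording change tracks what the einsum already does: contracting the head axis of `logits` against the per-token `weights` is exactly the batched product `Logits @ Head_Weights`, which is the same operation as a weighted sum over heads. A minimal standalone sketch (toy shapes and random inputs, not MaxText code) confirming the two readings agree:

```python
import jax
import jax.numpy as jnp

b, t, s, h = 2, 5, 5, 4  # toy batch, query, key, and index-head sizes
logits = jax.random.normal(jax.random.PRNGKey(0), (b, t, s, h))
weights = jax.random.normal(jax.random.PRNGKey(1), (b, t, h))

# The einsum used in the diff: contract the head axis per (query, key) pair.
score = jnp.einsum("btsh,bth->bts", logits, weights)

# Reading it as an explicit weighted sum over heads gives the same result.
score_ref = (logits * weights[:, :, None, :]).sum(axis=-1)
assert jnp.allclose(score, score_ref, atol=1e-5)
```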

src/MaxText/maxtext_utils.py

Lines changed: 110 additions & 12 deletions
```diff
@@ -321,11 +321,86 @@ def calculate_llama4_attention_tflops(config):
   return attention_tflops
 
 
+def calculate_indexer_mask_ratio(index_topk, max_target_length):
+  """
+  Calculates the sparse-to-dense ratio for Indexer TFLOPs.
+
+  The indexer evaluates all previous tokens in a causal manner until it hits
+  the Top-K limit.
+
+  Visual Representation (T=8, K=4):
+        Key (S) ->
+    Q1 [X . . . . . . .]  <- 1 token scored
+    Q2 [X X . . . . . .]  <- 2 tokens scored
+    Q3 [X X X . . . . .]  <- 3 tokens scored
+    Q4 [X X X X . . . .]  <- 4 tokens scored (K limit reached)
+    Q5 [X X X . X . . .]  <- 4 tokens scored
+    Q6 [X X . X . X . .]  <- 4 tokens scored
+    Q7 [X . X X . . X .]  <- 4 tokens scored
+    Q8 [X X . X . . . X]  <- 4 tokens scored
+
+  For MFU calculation, the same active area counts as if packed against the
+  diagonal:
+
+  Visual Representation (T=8, K=4):
+        Key (S) ->
+    Q1 [X . . . . . . .]  <- 1 token scored
+    Q2 [X X . . . . . .]  <- 2 tokens scored
+    Q3 [X X X . . . . .]  <- 3 tokens scored
+    Q4 [X X X X . . . .]  <- 4 tokens scored (K limit reached)
+    Q5 [X X X X . . . .]  <- 4 tokens scored
+    Q6 [X X X X . . . .]  <- 4 tokens scored
+    Q7 [X X X X . . . .]  <- 4 tokens scored
+    Q8 [X X X X . . . .]  <- 4 tokens scored
+
+  Mathematical Calculation:
+    - Triangle (Phase 1: 1 to K): K^2 / 2
+    - Rectangle (Phase 2: K+1 to T): (T - K) * K
+    - Total Active Area = TK - K^2 / 2
+    - Dense Area = T^2
+
+    Ratio = (TK - 0.5*K^2) / T^2 => (K/T) - 0.5*(K/T)^2
+  """
+
+  T = float(max_target_length)
+  K = float(index_topk)
+
+  ratio = K / T
+  mask_multiplier = ratio - (0.5 * ratio**2)
+  return mask_multiplier
+
+
```
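As a sanity check on the closed form (not part of the commit), counting min(t, K) scored keys per query directly matches the returned ratio, up to the triangle approximation (K^2/2 in place of the exact K(K+1)/2):

```python
def indexer_mask_ratio(index_topk, max_target_length):
  # mirrors calculate_indexer_mask_ratio above
  ratio = float(index_topk) / float(max_target_length)
  return ratio - 0.5 * ratio**2

T, K = 4096, 2048  # the DeepSeek-V3.2 settings used in this commit's test
direct = sum(min(t, K) for t in range(1, T + 1)) / T**2
print(direct, indexer_mask_ratio(K, T))  # 0.37506... vs 0.375
```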
```diff
+def calculate_indexer_tflops_per_device(config):
+  """Calculates TFLOPs for the DeepSeek Lightning Indexer (handles causal reduction)."""
+  batch_len = config.per_device_batch_size * config.max_target_length
+
+  # 1. Calculate projection flops
+  # Query: [batch, seq, q_lora_rank] @ [q_lora_rank, index_n_heads, index_head_dim]
+  q_flops = 2 * batch_len * config.q_lora_rank * config.index_n_heads * config.index_head_dim
+  # Key: [batch, seq, emb_dim] @ [emb_dim, index_head_dim]
+  k_flops = 2 * batch_len * config.emb_dim * config.index_head_dim
+  # Head weight: [batch, seq, emb_dim] @ [emb_dim, index_n_heads]
+  head_weight_flops = 2 * batch_len * config.emb_dim * config.index_n_heads
+  proj_flops = q_flops + k_flops + head_weight_flops
+
+  # 2. Calculate index score flops
+  # QK product [batch, seq, index_n_heads, index_head_dim] @ [batch, seq, index_head_dim]
+  # --> [batch, seq, seq, index_n_heads]
+  qk_product_flops = 2 * batch_len * config.max_target_length * config.index_n_heads * config.index_head_dim
+  # Aggregate heads [batch, seq, seq, index_n_heads] @ [batch, seq, index_n_heads]
+  head_reduction_flops = 2 * batch_len * config.max_target_length * config.index_n_heads
+  # Apply causal mask: divide by 2 to account for triangular interactions
+  # The mask restricts the indexer's search space prior to Top-K filtering
+  scoring_flops = (qk_product_flops + head_reduction_flops) / 2
+
+  return proj_flops, scoring_flops
+
+
```
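Plugging the DeepSeek-V3.2 shapes from this commit's test into these formulas (a standalone sketch with hard-coded dimensions, not a real config object) shows the projections and the causal scoring each land around half a TFLOP per step:

```python
B, S = 4, 4096
emb_dim, q_lora_rank = 7168, 1536
index_n_heads, index_head_dim = 64, 128
batch_len = B * S

q_flops = 2 * batch_len * q_lora_rank * index_n_heads * index_head_dim
k_flops = 2 * batch_len * emb_dim * index_head_dim
head_weight_flops = 2 * batch_len * emb_dim * index_n_heads
proj_flops = q_flops + k_flops + head_weight_flops  # ~0.46e12 flops

qk_product_flops = 2 * batch_len * S * index_n_heads * index_head_dim
head_reduction_flops = 2 * batch_len * S * index_n_heads
scoring_flops = (qk_product_flops + head_reduction_flops) / 2  # ~0.55e12 flops
```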
```diff
 def calculate_mla_tflops_per_device(config):
-  """Calculate Multi-Head Latent Attention TFLOP"""
+  """Calculate Multi-Head Latent Attention TFLOP (handles causal reduction)"""
   batch_len = config.per_device_batch_size * config.max_target_length
   qk_head_dim_sum = config.qk_nope_head_dim + config.qk_rope_head_dim
-  # calculate mla query projection
+
+  # 1. calculate mla query projection
   if config.q_lora_rank == 0:
     q_flops = 2 * batch_len * config.emb_dim * config.num_query_heads * qk_head_dim_sum
   else:
@@ -335,7 +410,8 @@ def calculate_mla_tflops_per_device(config):
       * batch_len
       * (config.emb_dim * config.q_lora_rank + config.q_lora_rank * config.num_query_heads * qk_head_dim_sum)
   )
-  # calculate mla kv projection with down and up flops
+
+  # 2. calculate mla kv projection
   kv_flops = (
       2
       * batch_len
@@ -346,9 +422,31 @@ def calculate_mla_tflops_per_device(config):
   )
   qkv_flops = q_flops + kv_flops
 
-  attention_flops = (
-      2 * batch_len * config.max_target_length * config.num_query_heads * (qk_head_dim_sum + config.v_head_dim)
-  )
+  # 3. calculate attention
+  if config.use_sparse_indexer and config.max_target_length > config.index_topk:
+    # get indexer flops
+    indexer_proj_flops, indexer_scoring_flops = calculate_indexer_tflops_per_device(config)
+    qkv_flops += indexer_proj_flops
+
+    # calculate the proportion of the T x T causal matrix that the Indexer actually explores
+    # this follows the area: (TK - 0.5*K^2) / T^2 (T: max_target_length, K: index_topk)
+    multiplier = calculate_indexer_mask_ratio(config.index_topk, config.max_target_length)
+    attention_flops = (
+        2
+        * batch_len
+        * config.max_target_length
+        * config.num_query_heads
+        * (qk_head_dim_sum + config.v_head_dim)
+        * multiplier
+    )
+    attention_flops += indexer_scoring_flops
+  else:
+    # standard MLA & max_target_length <= index_topk in sparse indexer
+    # in both cases, the indexer is bypassed as the causal mask remains the efficient representation
+    attention_flops = (
+        2 * batch_len * config.max_target_length * config.num_query_heads * (qk_head_dim_sum + config.v_head_dim)
+    )
+    attention_flops = attention_flops / 2
   projection_flops = 2 * batch_len * config.emb_dim * config.num_query_heads * config.v_head_dim
   return qkv_flops, attention_flops, projection_flops
```
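To see what the new branch changes, here is a standalone comparison (dimensions hard-coded from this commit's test; the real function reads them from `config`) of the dense-causal path against the sparse-indexer path for these shapes:

```python
B, S, n_heads = 4, 4096, 128
qk_head_dim_sum = 128 + 64  # qk_nope_head_dim + qk_rope_head_dim
v_head_dim = 128
topk = 2048
batch_len = B * S

full = 2 * batch_len * S * n_heads * (qk_head_dim_sum + v_head_dim)
dense_causal = full / 2               # else branch: causal mask halves the dense count
ratio = topk / S
multiplier = ratio - 0.5 * ratio**2   # 0.375 when K/T = 0.5
sparse = full * multiplier            # if branch, before adding indexer scoring flops
print(sparse / dense_causal)          # 0.75: pairwise attention flops drop by 25%
```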

The causal divide-by-2 moves into the non-MLA branch, since the MLA path now returns causal flops directly:

```diff
@@ -548,7 +646,7 @@ def calculate_tflops_training_per_device(config, log=True):
 
   # Attention flops
   if config.attention_type == "mla":
-    qkv_flops, noncausal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
+    qkv_flops, causal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
   else:
     qkv_flops = (
         2
@@ -570,11 +668,11 @@
         * config.head_dim
     )
 
-  # Divide attention flops by 2 due to causal mask
-  # References:
-  # NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
-  # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
-  causal_attention_flops = noncausal_attention_flops / 2
+    # Divide attention flops by 2 due to causal mask
+    # References:
+    # NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
+    # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
+    causal_attention_flops = noncausal_attention_flops / 2
 
   # Embedding flops
   embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size
```

tests/unit/flop_calculation_test.py

Lines changed: 52 additions & 0 deletions
```diff
@@ -292,3 +292,55 @@ def test_gpt_oss_20b_flops(self):
     )
     calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
     self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
+
+  @pytest.mark.cpu_only
+  def test_deepseek32_671b_flops(self):
+    """Test DeepSeek3.2-671b FLOPs calculation"""
+    kwargs = {
+        # Model bases
+        "model_name": "deepseek3.2-671b",
+        "override_model_config": True,
+        # Core workload parameters
+        "per_device_batch_size": 4,
+        "max_target_length": 4096,
+        "num_experts": 256,
+        "num_experts_per_tok": 8,
+        "shared_experts": 1,
+        # Model dimensions
+        "base_emb_dim": 7168,
+        "base_num_query_heads": 128,
+        "base_num_kv_heads": 128,
+        "base_mlp_dim": 18432,
+        "base_moe_mlp_dim": 2048,
+        "base_num_decoder_layers": 61,
+        "first_num_dense_layers": 3,
+        "mlp_activations": ["silu", "linear"],
+        "vocab_size": 129280,
+        # MLA
+        "q_lora_rank": 1536,
+        "kv_lora_rank": 512,
+        "qk_nope_head_dim": 128,
+        "qk_rope_head_dim": 64,
+        "v_head_dim": 128,
+        "skip_jax_distributed_system": True,
+        # Indexer for DeepSeek Sparse Attention
+        "use_sparse_indexer": True,
+        "index_n_heads": 64,
+        "index_head_dim": 128,
+        "index_topk": 2048,
+        # TODO(ranran): remove after flash attention is supported
+        "attention": "dot_product",
+    }
+    B = kwargs["per_device_batch_size"]
+    S = kwargs["max_target_length"]
+    attention_flops = self.compute_deepseek_attention_flops_per_device(kwargs)
+    # deepseek3-671b has ~37B active parameters
+    # https://arxiv.org/pdf/2412.19437
+    golden_param_size = 37e9
+    golden_tflops = 6 * B * S * golden_param_size / 1e12 + attention_flops
+    cfg = pyconfig.initialize(
+        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        **kwargs,
+    )
+    calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
+    self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
```
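For reference, the dense-parameter portion of `golden_tflops` works out as below; the attention term from `compute_deepseek_attention_flops_per_device` (a helper defined elsewhere in this test file) is added on top:

```python
B, S, golden_param_size = 4, 4096, 37e9
dense_tflops = 6 * B * S * golden_param_size / 1e12
print(dense_tflops)  # ~3637.2 TFLOPs per device, before the attention term
```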
