@@ -319,11 +319,86 @@ def calculate_llama4_attention_tflops(config):
   return attention_tflops
 
 
+def calculate_indexer_mask_ratio(index_topk, max_target_length):
+  """
+  Calculates the sparse-to-dense ratio for Indexer TFLOPs.
+
+  The indexer evaluates all previous tokens in a causal manner until it hits
+  the Top-K limit.
+
+  Visual Representation (T=8, K=4):
+      Key (S) ->
+  Q1 [X . . . . . . .] <- 1 token scored
+  Q2 [X X . . . . . .] <- 2 tokens scored
+  Q3 [X X X . . . . .] <- 3 tokens scored
+  Q4 [X X X X . . . .] <- 4 tokens scored (K limit reached)
+  Q5 [X X X . X . . .] <- 4 tokens scored
+  Q6 [X X . X . X . .] <- 4 tokens scored
+  Q7 [X . X X . . X .] <- 4 tokens scored
+  Q8 [X X . X . . . X] <- 4 tokens scored
+
+  For the MFU calculation, only the count of scored tokens per query row
+  matters, not which tokens were selected, so the same work can be drawn
+  as a dense causal band clipped at K:
+
+  Visual Representation (T=8, K=4):
+      Key (S) ->
+  Q1 [X . . . . . . .] <- 1 token scored
+  Q2 [X X . . . . . .] <- 2 tokens scored
+  Q3 [X X X . . . . .] <- 3 tokens scored
+  Q4 [X X X X . . . .] <- 4 tokens scored (K limit reached)
+  Q5 [X X X X . . . .] <- 4 tokens scored
+  Q6 [X X X X . . . .] <- 4 tokens scored
+  Q7 [X X X X . . . .] <- 4 tokens scored
+  Q8 [X X X X . . . .] <- 4 tokens scored
+
+  Mathematical Calculation:
+  - Triangle (Phase 1: rows 1 to K): K^2 / 2
+  - Rectangle (Phase 2: rows K+1 to T): (T - K) * K
+  - Total Active Area = T*K - K^2 / 2
+  - Dense Area = T^2
+
+  Ratio = (T*K - 0.5*K^2) / T^2 = (K/T) - 0.5*(K/T)^2
+  """
+
+  T = float(max_target_length)
+  K = float(index_topk)
+
+  ratio = K / T
+  mask_multiplier = ratio - (0.5 * ratio**2)
+  return mask_multiplier
+
+
+def calculate_indexer_tflops_per_device(config):
+  """Calculates TFLOPs for the DeepSeek Lightning Indexer (handles causal reduction)."""
+  batch_len = config.per_device_batch_size * config.max_target_length
+
+  # 1. Calculate projection flops
+  # Query: [batch, seq, q_lora_rank] @ [q_lora_rank, index_n_heads, index_head_dim]
+  q_flops = 2 * batch_len * config.q_lora_rank * config.index_n_heads * config.index_head_dim
+  # Key: [batch, seq, emb_dim] @ [emb_dim, index_head_dim]
+  k_flops = 2 * batch_len * config.emb_dim * config.index_head_dim
+  # Head weight: [batch, seq, emb_dim] @ [emb_dim, index_n_heads]
+  head_weight_flops = 2 * batch_len * config.emb_dim * config.index_n_heads
+  proj_flops = q_flops + k_flops + head_weight_flops
+
+  # 2. Calculate index score flops
+  # QK product: [batch, seq, index_n_heads, index_head_dim] @ [batch, seq, index_head_dim]
+  #   -> [batch, seq, seq, index_n_heads]
+  qk_product_flops = 2 * batch_len * config.max_target_length * config.index_n_heads * config.index_head_dim
+  # Aggregate heads: [batch, seq, seq, index_n_heads] @ [batch, seq, index_n_heads]
+  head_reduction_flops = 2 * batch_len * config.max_target_length * config.index_n_heads
+  # Apply causal mask: divide by 2 to account for triangular interactions.
+  # The mask restricts the indexer's search space prior to Top-K filtering.
+  scoring_flops = (qk_product_flops + head_reduction_flops) / 2
+
+  return proj_flops, scoring_flops
+
+
 def calculate_mla_tflops_per_device(config):
-  """Calculate Multi-Head Latent Attention TFLOP"""
+  """Calculate Multi-Head Latent Attention TFLOPs (handles causal reduction)."""
   batch_len = config.per_device_batch_size * config.max_target_length
   qk_head_dim_sum = config.qk_nope_head_dim + config.qk_rope_head_dim
-  # calculate mla query projection
+
+  # 1. calculate mla query projection
   if config.q_lora_rank == 0:
     q_flops = 2 * batch_len * config.emb_dim * config.num_query_heads * qk_head_dim_sum
   else:
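
The triangle term in `calculate_indexer_mask_ratio` uses K^2/2, a continuous approximation of the exact row count K*(K+1)/2, so the closed form slightly undercounts at toy sizes and converges at realistic sequence lengths. A quick standalone check of the area math; both helpers below are hypothetical, written only for this illustration:

```python
def brute_force_mask_ratio(index_topk, max_target_length):
  """Directly count scored (query, key) pairs: row q scores min(q, K) keys."""
  scored = sum(min(q, index_topk) for q in range(1, max_target_length + 1))
  return scored / max_target_length**2


def closed_form_mask_ratio(index_topk, max_target_length):
  """Mirrors calculate_indexer_mask_ratio: (K/T) - 0.5*(K/T)^2."""
  ratio = index_topk / max_target_length
  return ratio - 0.5 * ratio**2


# T=8, K=4 from the docstring: exact count is (1+2+3+4) + 4*4 = 26 -> 0.40625,
# while the closed form gives 0.375.
print(brute_force_mask_ratio(4, 8), closed_form_mask_ratio(4, 8))
# At realistic lengths the two agree to about five decimal places.
print(brute_force_mask_ratio(2048, 131072), closed_form_mask_ratio(2048, 131072))
```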
@@ -333,7 +408,8 @@ def calculate_mla_tflops_per_device(config):
         * batch_len
         * (config.emb_dim * config.q_lora_rank + config.q_lora_rank * config.num_query_heads * qk_head_dim_sum)
     )
-  # calculate mla kv projection with down and up flops
+
+  # 2. calculate mla kv projection
   kv_flops = (
       2
       * batch_len
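
The `q_lora_rank != 0` branch above counts the low-rank query path as two chained matmuls: a down-projection into the rank-`q_lora_rank` space followed by an up-projection to the per-head dimensions. A small algebraic check that the fused expression in the diff equals the two matmuls counted separately; the dimensions are hypothetical, DeepSeek-V3-like stand-ins, not taken from this diff:

```python
# Hypothetical dimensions for illustration; the real values come from config.
batch_len = 1 * 4096                                # per_device_batch_size * max_target_length
emb_dim, q_lora_rank = 7168, 1536
num_query_heads, qk_head_dim_sum = 128, 128 + 64    # qk_nope_head_dim + qk_rope_head_dim

# Two chained matmuls: x @ W_down ([emb_dim, r]), then @ W_up ([r, heads * head_dim]).
down_flops = 2 * batch_len * emb_dim * q_lora_rank
up_flops = 2 * batch_len * q_lora_rank * num_query_heads * qk_head_dim_sum

# The diff's single expression is exactly their sum.
q_flops = 2 * batch_len * (emb_dim * q_lora_rank + q_lora_rank * num_query_heads * qk_head_dim_sum)
assert q_flops == down_flops + up_flops
```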
@@ -344,9 +420,31 @@ def calculate_mla_tflops_per_device(config):
   )
   qkv_flops = q_flops + kv_flops
 
-  attention_flops = (
-      2 * batch_len * config.max_target_length * config.num_query_heads * (qk_head_dim_sum + config.v_head_dim)
-  )
+  # 3. calculate attention
+  if config.use_sparse_indexer and config.max_target_length > config.index_topk:
+    # get indexer flops
+    indexer_proj_flops, indexer_scoring_flops = calculate_indexer_tflops_per_device(config)
+    qkv_flops += indexer_proj_flops
+
+    # calculate the proportion of the T x T causal matrix that the indexer actually explores;
+    # this follows the area (T*K - 0.5*K^2) / T^2 (T: max_target_length, K: index_topk)
+    multiplier = calculate_indexer_mask_ratio(config.index_topk, config.max_target_length)
+    attention_flops = (
+        2
+        * batch_len
+        * config.max_target_length
+        * config.num_query_heads
+        * (qk_head_dim_sum + config.v_head_dim)
+        * multiplier
+    )
+    attention_flops += indexer_scoring_flops
+  else:
+    # standard MLA, or a sparse indexer with max_target_length <= index_topk;
+    # in both cases the indexer is bypassed since the causal mask is already the efficient representation
+    attention_flops = (
+        2 * batch_len * config.max_target_length * config.num_query_heads * (qk_head_dim_sum + config.v_head_dim)
+    )
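+    # Halve for the causal mask; the sparse branch above folds this reduction into its multiplier.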
+    attention_flops = attention_flops / 2
   projection_flops = 2 * batch_len * config.emb_dim * config.num_query_heads * config.v_head_dim
   return qkv_flops, attention_flops, projection_flops
 
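
With K/T = 0.5, the indexer multiplier is 0.375 versus the 0.5 implied by plain causal halving, so the sparse branch books 25% fewer core attention flops (before adding the indexer's own scoring flops back in). A standalone numeric sketch with hypothetical shapes, not part of the change:

```python
# Hypothetical shapes for illustration only.
batch_len = 1 * 4096                                  # per_device_batch_size * max_target_length
max_target_length, index_topk = 4096, 2048
num_query_heads, qk_head_dim_sum, v_head_dim = 16, 192, 128

dense = 2 * batch_len * max_target_length * num_query_heads * (qk_head_dim_sum + v_head_dim)

ratio = index_topk / max_target_length                # K/T = 0.5
sparse_core = dense * (ratio - 0.5 * ratio**2)        # multiplier = 0.375
causal_core = dense / 2                               # else-branch halving

print(sparse_core / causal_core)                      # 0.75: 25% fewer core attention flops
```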
@@ -546,7 +644,7 @@ def calculate_tflops_training_per_device(config, log=True):
 
   # Attention flops
   if config.attention_type == "mla":
-    qkv_flops, noncausal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
+    qkv_flops, causal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
   else:
     qkv_flops = (
         2
@@ -568,11 +666,11 @@ def calculate_tflops_training_per_device(config, log=True):
         * config.head_dim
     )
 
-  # Divide attention flops by 2 due to causal mask
-  # References:
-  # NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
-  # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
-  causal_attention_flops = noncausal_attention_flops / 2
+    # Divide attention flops by 2 due to causal mask
+    # References:
+    # NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
+    # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
+    causal_attention_flops = noncausal_attention_flops / 2
 
   # Embedding flops
   embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size
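
The substantive change in this last hunk is the re-indentation: the halving now lives only inside the non-MLA branch, because `calculate_mla_tflops_per_device` already returns causally (or indexer-sparsely) reduced attention flops. A toy sketch with made-up flop numbers of why a top-level halving would now double-count:

```python
def causal_flops(attention_type, mla_causal=100.0, dense_noncausal=100.0):
  """Toy mirror of the dispatch: MLA output is already causal, so only the
  generic path halves its noncausal estimate (Megatron-LM / NeMo convention)."""
  if attention_type == "mla":
    return mla_causal              # no further halving
  return dense_noncausal / 2       # causal-mask halving lives here only

print(causal_flops("mla"))    # 100.0 (a top-level /2 would wrongly report 50.0)
print(causal_flops("dense"))  # 50.0
```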