@@ -321,11 +321,86 @@ def calculate_llama4_attention_tflops(config):
321321 return attention_tflops
322322
323323
def calculate_indexer_mask_ratio(index_topk, max_target_length):
  """Return the sparse-to-dense FLOPs ratio for the indexer's causal Top-K mask.

  Row i of the T x T causal score matrix covers min(i, K) keys: the first K
  query rows form a triangle and the remaining (T - K) rows each score exactly
  K keys (a rectangle). For MFU purposes the active area is approximated as

      active  = T*K - 0.5*K^2        (triangle ~ K^2/2 plus (T-K)*K rectangle)
      dense   = T^2
      ratio   = (K/T) - 0.5 * (K/T)^2

  Note: the triangle term uses K^2/2 rather than the exact K*(K+1)/2; for the
  large K used in practice the difference is negligible.

  Args:
    index_topk: Top-K limit K applied by the indexer.
    max_target_length: Sequence length T.

  Returns:
    Fraction (float) of the dense T x T score matrix the indexer evaluates.
  """
  # Same arithmetic order as the closed form above: r - 0.5*r^2 with r = K/T.
  fraction = float(index_topk) / float(max_target_length)
  return fraction - 0.5 * fraction**2
370+
371+
def calculate_indexer_tflops_per_device(config):
  """Calculates FLOPs for the DeepSeek Lightning Indexer (handles causal reduction).

  Args:
    config: Model config exposing per_device_batch_size, max_target_length,
      q_lora_rank, index_n_heads, index_head_dim and emb_dim.

  Returns:
    Tuple of (proj_flops, scoring_flops): the indexer projection FLOPs and the
    causally-masked index-scoring FLOPs per device.
  """
  tokens = config.per_device_batch_size * config.max_target_length
  heads = config.index_n_heads
  head_dim = config.index_head_dim

  # 1. Projection FLOPs (one GEMM each; 2 FLOPs per multiply-accumulate):
  #    Query:        [batch, seq, q_lora_rank] @ [q_lora_rank, heads, head_dim]
  #    Key:          [batch, seq, emb_dim]     @ [emb_dim, head_dim]
  #    Head weights: [batch, seq, emb_dim]     @ [emb_dim, heads]
  proj_flops = 2 * tokens * (
      config.q_lora_rank * heads * head_dim
      + config.emb_dim * head_dim
      + config.emb_dim * heads
  )

  # 2. Index-score FLOPs:
  #    QK product -> [batch, seq, seq, heads] costs heads*head_dim per (q, k)
  #    pair; aggregating heads adds another `heads` per pair.
  dense_scoring_flops = 2 * tokens * config.max_target_length * (heads * head_dim + heads)
  # Causal mask halves the (query, key) pairs scored prior to Top-K filtering.
  scoring_flops = dense_scoring_flops / 2

  return proj_flops, scoring_flops
396+
397+
324398def calculate_mla_tflops_per_device (config ):
325- """Calculate Multi-Head Latent Attention TFLOP"""
399+ """Calculate Multi-Head Latent Attention TFLOP (handles causal reduction) """
326400 batch_len = config .per_device_batch_size * config .max_target_length
327401 qk_head_dim_sum = config .qk_nope_head_dim + config .qk_rope_head_dim
328- # calculate mla query projection
402+
403+ # 1. calculate mla query projection
329404 if config .q_lora_rank == 0 :
330405 q_flops = 2 * batch_len * config .emb_dim * config .num_query_heads * qk_head_dim_sum
331406 else :
@@ -335,7 +410,8 @@ def calculate_mla_tflops_per_device(config):
335410 * batch_len
336411 * (config .emb_dim * config .q_lora_rank + config .q_lora_rank * config .num_query_heads * qk_head_dim_sum )
337412 )
338- # calculate mla kv projection with down and up flops
413+
414+ # 2. calculate mla kv projection
339415 kv_flops = (
340416 2
341417 * batch_len
@@ -346,9 +422,31 @@ def calculate_mla_tflops_per_device(config):
346422 )
347423 qkv_flops = q_flops + kv_flops
348424
349- attention_flops = (
350- 2 * batch_len * config .max_target_length * config .num_query_heads * (qk_head_dim_sum + config .v_head_dim )
351- )
425+ # 3. calculate attention
426+ if config .use_sparse_indexer and config .max_target_length > config .index_topk :
427+ # get indexer flops
428+ indexer_proj_flops , indexer_scoring_flops = calculate_indexer_tflops_per_device (config )
429+ qkv_flops += indexer_proj_flops
430+
431+ # calculate the proportion of the T x T causal matrix that the Indexer actually explores
432+ # this follows the area: (TK - 0.5*K^2) / T^2 (T: max_target_length, K: index_topk)
433+ multiplier = calculate_indexer_mask_ratio (config .index_topk , config .max_target_length )
434+ attention_flops = (
435+ 2
436+ * batch_len
437+ * config .max_target_length
438+ * config .num_query_heads
439+ * (qk_head_dim_sum + config .v_head_dim )
440+ * multiplier
441+ )
442+ attention_flops += indexer_scoring_flops
443+ else :
444+ # standard MLA & max_target_length <= index_topk in sparse indexer
445+ # in both cases, the indexer is bypassed as the causal mask remains the efficient representation
446+ attention_flops = (
447+ 2 * batch_len * config .max_target_length * config .num_query_heads * (qk_head_dim_sum + config .v_head_dim )
448+ )
449+ attention_flops = attention_flops / 2
352450 projection_flops = 2 * batch_len * config .emb_dim * config .num_query_heads * config .v_head_dim
353451 return qkv_flops , attention_flops , projection_flops
354452
@@ -548,7 +646,7 @@ def calculate_tflops_training_per_device(config, log=True):
548646
549647 # Attention flops
550648 if config .attention_type == "mla" :
551- qkv_flops , noncausal_attention_flops , projection_flops = calculate_mla_tflops_per_device (config )
649+ qkv_flops , causal_attention_flops , projection_flops = calculate_mla_tflops_per_device (config )
552650 else :
553651 qkv_flops = (
554652 2
@@ -570,11 +668,11 @@ def calculate_tflops_training_per_device(config, log=True):
570668 * config .head_dim
571669 )
572670
573- # Divide attention flops by 2 due to causal mask
574- # References:
575- # NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
576- # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
577- causal_attention_flops = noncausal_attention_flops / 2
671+ # Divide attention flops by 2 due to causal mask
672+ # References:
673+ # NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
674+ # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
675+ causal_attention_flops = noncausal_attention_flops / 2
578676
579677 # Embedding flops
580678 embedding_flops = 2 * config .per_device_batch_size * config .max_target_length * config .emb_dim * config .vocab_size
0 commit comments