@@ -669,6 +669,38 @@ def calculate_llama4_vision_layers_tflops_per_device(config):
   return total_tflops, learnable_weight_tflops, total_attn_tflops


+def calculate_engram_tflops(config):
+  """Calculate engram TFLOPs per device."""
+  B = config.per_device_batch_size
+  S = config.max_target_length
+  G = config.mhc_expansion_rate  # Multi-manifold branches
+  D = config.emb_dim  # Hidden dimension
+  k = config.engram_kernel_size  # Conv window
+
+  num_ngram_orders = config.engram_max_ngram_size - 1
+  engram_dim = config.engram_num_heads * config.engram_head_dim * num_ngram_orders
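+  # Example with hypothetical values: engram_num_heads=8, engram_head_dim=64, and
+  # engram_max_ngram_size=4 would give engram_dim = 8 * 64 * (4 - 1) = 1536.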
+
+  # 1. Key Projection
+  key_flops = 2 * (B * S) * engram_dim * (G * D)
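+  # (2 flops per multiply-accumulate: B * S tokens through a (G * D) x engram_dim projection)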
+  # 2. Value Projection
+  value_flops = 2 * (B * S) * engram_dim * D
+  # 3. QK Attention
+  attention_flops = 2 * (B * S) * G * D
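+  # (2 flops per channel per token, i.e. one length G * D dot product per token; linear in S)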
+  # 4. Short Convolution
+  # Standard conv flops: 2 * kernel_size * input_channels * output_channels / feature_group_count.
+  # In Engram, feature_group_count = input_channels = output_channels (a depthwise conv),
+  # so the cost reduces to 2 * k flops per channel per token.
+  # Unlike global attention, convolution work is constant per token (not O(S^2)),
+  # so we do not apply the 0.5 causal scaling factor.
+  total_channels = G * D
+  conv_flops = 2 * (B * S) * k * total_channels
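+  # Example with hypothetical values: G=4, D=2048, k=4 would give
+  # 2 * 4 * (4 * 2048) = 65536 conv flops per token.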
+
+  num_layers = len(config.engram_layers)
+  # Account for both the forward (1x) and backward (2x) passes.
+  learnable_tflops = num_layers * (key_flops + value_flops + conv_flops) * 3 / 1e12
+  attention_tflops = num_layers * attention_flops * 3 / 1e12
+  return learnable_tflops, attention_tflops
+
+
 def calculate_vision_encoder_tflops(config):
   """Calculate vision encoder TFLOPs per prefill step per device."""
   if config.model_name.startswith("gemma3"):
@@ -786,18 +818,11 @@ def calculate_tflops_training_per_device(config, log=True):
   )
   attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12

-  learnable_weight_tflops = learnable_weight_tflops * config.gradient_accumulation_steps
-  attention_tflops = attention_tflops * config.gradient_accumulation_steps
-
-  # DPO includes one additional forward pass per gradient accumulation step
-  if config.use_dpo:
-    reference_model_tflops = learnable_weight_tflops / 3  # additional forward pass
-    reference_model_attention_tflops = attention_tflops / 3
-    attention_tflops = attention_tflops + reference_model_attention_tflops
-  else:
-    reference_model_tflops = 0
-
-  total_tflops = learnable_weight_tflops + attention_tflops + reference_model_tflops
+  # Engram flops
+  if config.engram_layers:
+    engram_learnable_tflops, engram_attention_tflops = calculate_engram_tflops(config)
+    learnable_weight_tflops += engram_learnable_tflops
+    attention_tflops += engram_attention_tflops

   if config.use_multimodal:
     # Add vision layers TFLOPs for multimodal models
@@ -810,10 +835,22 @@ def calculate_tflops_training_per_device(config, log=True):
810835 f"and { 100 * mm_attention_tflops / mm_total_tflops :.2f} % attention flops;\n " ,
811836 f"learnable weight { mm_learnable_weight_tflops :.2f} TFLOPs, attention { mm_attention_tflops :.2f} TFLOPs" ,
812837 )
-    total_tflops += mm_total_tflops
     learnable_weight_tflops += mm_learnable_weight_tflops
     attention_tflops += mm_attention_tflops

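+  # Scale by gradient accumulation only after all per-microbatch components
+  # (decoder, engram, vision) have been summed.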
+  learnable_weight_tflops = learnable_weight_tflops * config.gradient_accumulation_steps
+  attention_tflops = attention_tflops * config.gradient_accumulation_steps
+
+  # DPO includes one additional forward pass per gradient accumulation step
+  if config.use_dpo:
+    reference_model_tflops = learnable_weight_tflops / 3  # additional forward pass
+    reference_model_attention_tflops = attention_tflops / 3
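+    # (training cost is 3x a single forward pass, so /3 recovers the forward-only
+    # cost of the frozen reference model)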
+    attention_tflops = attention_tflops + reference_model_attention_tflops
+  else:
+    reference_model_tflops = 0
+
+  total_tflops = learnable_weight_tflops + attention_tflops + reference_model_tflops
+
   if log:
     print(
         "Per train step:\n",