
Commit 20b297f

Merge pull request #3223 from AI-Hypercomputer:engram_flops_clean
PiperOrigin-RevId: 877595416
2 parents 23a82de + ccf89db commit 20b297f

3 files changed

Lines changed: 129 additions & 19 deletions


src/maxtext/layers/engram.py

Lines changed: 6 additions & 6 deletions
@@ -469,7 +469,7 @@ def __call__(self, x: Array) -> Array:
     B: Batch size
     S: Sequence length (temporal dimension)
     G: Number of branches (mhc_expansion_rate)
-    D: Hidden size (base_emb_dim)
+    D: Hidden size (emb_dim)
     """
     B, S, G, D = x.shape

@@ -557,7 +557,7 @@ def __init__(
     # retrieved n-gram memory -> Key, from D_en to [G, D]
     self.key_proj = DenseGeneral(
         in_features_shape=self.engram_dim,
-        out_features_shape=(mhc_expansion_rate, config.base_emb_dim),
+        out_features_shape=(mhc_expansion_rate, config.emb_dim),
         axis=-1,
         kernel_init=self.kernel_init,
         kernel_axes=("engram_dim", "mhc", "embed"),
@@ -578,7 +578,7 @@ def __init__(
     @nnx.vmap(in_axes=0, out_axes=0)
     def create_norms(rngs):
       return RMSNorm(
-          num_features=config.base_emb_dim,
+          num_features=config.emb_dim,
           dtype=config.dtype,
           weight_dtype=config.weight_dtype,
           epsilon=config.normalization_layer_epsilon,
@@ -594,7 +594,7 @@ def create_norms(rngs):
     # Value Projection (Shared): Retrieved memory -> Value
     self.value_proj = DenseGeneral(
         in_features_shape=self.engram_dim,
-        out_features_shape=config.base_emb_dim,
+        out_features_shape=config.emb_dim,
         axis=-1,
         kernel_init=self.kernel_init,
         kernel_axes=("engram_dim", "embed"),
@@ -611,7 +611,7 @@ def create_norms(rngs):
     # Applies depthwise causal convolution to smooth the retrieved memory over time.
     self.short_conv = ShortConv(
         config=config,
-        hidden_size=config.base_emb_dim,
+        hidden_size=config.emb_dim,
         kernel_size=self.conv_kernel_size,
         dilation=self.max_ngram_size,
         mhc_expansion_rate=mhc_expansion_rate,
@@ -635,7 +635,7 @@ def __call__(self, hidden_states: Array, hash_input_ids: Array) -> Array:
     S: Sequence Length
     G: mhc_expansion_rate, Number of Branches
     H_total: Total number of heads across n-grams. num_head * num_ngrams
-    D: base_emb_dim
+    D: emb_dim
     D_head: Dimension of a single head embedding
     D_en: Dimension of flattened embedding across heads and n-grams
     """

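The engram.py changes only swap config.base_emb_dim for config.emb_dim (the scaled hidden size) in the Engram key/value projections, norms, and short convolution, so the module tracks the model width that is actually used at runtime. As a shape-only illustration of what those projections produce, here is a minimal sketch with made-up sizes and plain einsum kernels standing in for the actual DenseGeneral modules:

import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only to make the shapes concrete.
B, S, G, D = 2, 16, 4, 256   # batch, sequence, branches (mhc_expansion_rate), emb_dim
D_en = 8 * 64 * 2            # engram_num_heads * engram_head_dim * (max_ngram_size - 1)

rng = jax.random.PRNGKey(0)
memory = jax.random.normal(rng, (B, S, D_en))        # retrieved n-gram memory, flattened over heads/n-grams
w_key = jnp.zeros((D_en, G, D))                      # stand-in for the key_proj kernel: engram_dim -> (G, D)
w_value = jnp.zeros((D_en, D))                       # stand-in for the shared value_proj kernel: engram_dim -> D

keys = jnp.einsum("bse,egd->bsgd", memory, w_key)    # [B, S, G, D], one key per branch
values = jnp.einsum("bse,ed->bsd", memory, w_value)  # [B, S, D], shared across branches
print(keys.shape, values.shape)                      # (2, 16, 4, 256) (2, 16, 256)

Using emb_dim keeps these projections consistent with the rest of the model when the base dimensions are scaled.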
src/maxtext/utils/maxtext_utils.py

Lines changed: 50 additions & 13 deletions
@@ -669,6 +669,38 @@ def calculate_llama4_vision_layers_tflops_per_device(config):
   return total_tflops, learnable_weight_tflops, total_attn_tflops


+def calculate_engram_tflops(config):
+  """Calculate engram TFLOPs per device."""
+  B = config.per_device_batch_size
+  S = config.max_target_length
+  G = config.mhc_expansion_rate  # Multi-manifold branches
+  D = config.emb_dim  # Hidden dimension
+  k = config.engram_kernel_size  # Conv window
+
+  num_ngram_orders = config.engram_max_ngram_size - 1
+  engram_dim = config.engram_num_heads * config.engram_head_dim * num_ngram_orders
+
+  # 1. Key Projection
+  key_flops = 2 * (B * S) * engram_dim * (G * D)
+  # 2. Value Projection
+  value_flops = 2 * (B * S) * engram_dim * D
+  # 3. QK Attention
+  attention_flops = 2 * (B * S) * G * D
+  # 4. Short Convolution
+  # Standard flops as 2 * kernel_size * input_channels * output_channels / feature_group_count
+  # In Engram, the feature_group_count = input_channels = output_channels
+  # Unlike global attention, convolution work is constant per token (not O(S^2)),
+  # so we do not apply the 0.5 causal scaling factor.
+  total_channels = G * D
+  conv_flops = 2 * (B * S) * k * total_channels
+
+  num_layers = len(config.engram_layers)
+  # account for both the forward (1x) and backward (2x) passes
+  learnable_tflops = num_layers * (key_flops + value_flops + conv_flops) * 3 / 1e12
+  attention_tflops = num_layers * attention_flops * 3 / 1e12
+  return learnable_tflops, attention_tflops
+
+
 def calculate_vision_encoder_tflops(config):
   """Calculate vision encoder TFLOPs per prefill step per device."""
   if config.model_name.startswith("gemma3"):

@@ -786,18 +818,11 @@ def calculate_tflops_training_per_device(config, log=True):
   )
   attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12

-  learnable_weight_tflops = learnable_weight_tflops * config.gradient_accumulation_steps
-  attention_tflops = attention_tflops * config.gradient_accumulation_steps
-
-  # DPO includes one additional forward pass per gradient accumulation step
-  if config.use_dpo:
-    reference_model_tflops = learnable_weight_tflops / 3  # additional forward pass
-    reference_model_attention_tflops = attention_tflops / 3
-    attention_tflops = attention_tflops + reference_model_attention_tflops
-  else:
-    reference_model_tflops = 0
-
-  total_tflops = learnable_weight_tflops + attention_tflops + reference_model_tflops
+  # Engram flops
+  if config.engram_layers:
+    engram_learnable_tflops, engram_attention_tflops = calculate_engram_tflops(config)
+    learnable_weight_tflops += engram_learnable_tflops
+    attention_tflops += engram_attention_tflops

   if config.use_multimodal:
     # Add vision layers TFLOPs for multimodal models

@@ -810,10 +835,22 @@
       f"and {100 * mm_attention_tflops/mm_total_tflops:.2f}% attention flops;\n",
       f"learnable weight {mm_learnable_weight_tflops:.2f} TFLOPs, attention {mm_attention_tflops:.2f} TFLOPs",
     )
-    total_tflops += mm_total_tflops
     learnable_weight_tflops += mm_learnable_weight_tflops
     attention_tflops += mm_attention_tflops

+  learnable_weight_tflops = learnable_weight_tflops * config.gradient_accumulation_steps
+  attention_tflops = attention_tflops * config.gradient_accumulation_steps
+
+  # DPO includes one additional forward pass per gradient accumulation step
+  if config.use_dpo:
+    reference_model_tflops = learnable_weight_tflops / 3  # additional forward pass
+    reference_model_attention_tflops = attention_tflops / 3
+    attention_tflops = attention_tflops + reference_model_attention_tflops
+  else:
+    reference_model_tflops = 0
+
+  total_tflops = learnable_weight_tflops + attention_tflops + reference_model_tflops
+
   if log:
     print(
       "Per train step:\n",

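calculate_engram_tflops above counts each projection as 2 x tokens x in_features x out_features multiply-accumulate FLOPs and multiplies by 3 to cover the forward and backward passes. A standalone sketch of that arithmetic, using sizes borrowed from the unit test below purely as an illustration:

# Illustrative arithmetic only; this mirrors the formula in calculate_engram_tflops.
B, S, G, D, k = 4, 8192, 1, 2048, 4        # batch, seq len, branches, emb_dim, conv kernel
num_heads, head_dim, max_ngram = 8, 1280, 3
num_layers = 2                             # len(engram_layers)

engram_dim = num_heads * head_dim * (max_ngram - 1)   # 20480
key_flops = 2 * (B * S) * engram_dim * (G * D)
value_flops = 2 * (B * S) * engram_dim * D
attention_flops = 2 * (B * S) * G * D
conv_flops = 2 * (B * S) * k * (G * D)

learnable_tflops = num_layers * (key_flops + value_flops + conv_flops) * 3 / 1e12
attention_tflops = num_layers * attention_flops * 3 / 1e12
print(f"{learnable_tflops:.2f} learnable TFLOPs, {attention_tflops:.6f} attention TFLOPs")
# ~32.99 learnable TFLOPs; the attention term stays tiny because it is linear in S, not O(S^2).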
tests/unit/flop_calculation_test.py

Lines changed: 73 additions & 0 deletions
@@ -487,3 +487,76 @@ def test_deepseek32_671b_flops(self):
     )
     calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
     self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)
+
+  @pytest.mark.cpu_only
+  def test_custom_engram_flops(self):
+    """Test model with Engram FLops calculation"""
+    kwargs = {
+        # Model bases
+        "model_name": "deepseek2-16b",
+        "override_model_config": True,
+        # Core workload parameters
+        "per_device_batch_size": 4,
+        "max_target_length": 8192,
+        "num_experts": 64,
+        "num_experts_per_tok": 6,
+        "shared_experts": 2,
+        # Model dimensions
+        "base_emb_dim": 2048,
+        "base_num_query_heads": 16,
+        "base_num_kv_heads": 16,
+        "base_mlp_dim": 10944,
+        "base_moe_mlp_dim": 1408,
+        "base_num_decoder_layers": 27,
+        "first_num_dense_layers": 1,
+        "mlp_activations": ["silu", "linear"],
+        "vocab_size": 102400,
+        # MLA
+        "q_lora_rank": 0,
+        "kv_lora_rank": 512,
+        "qk_nope_head_dim": 128,
+        "qk_rope_head_dim": 64,
+        "v_head_dim": 128,
+        "skip_jax_distributed_system": True,
+        # Engram
+        "mhc_expansion_rate": 1,
+        "engram_layers": [2, 15],
+        "engram_num_heads": 8,
+        "engram_head_dim": 1280,
+        "engram_kernel_size": 4,
+        "engram_max_ngram_size": 3,
+        "engram_vocab_bases": [226240, 226240],
+        "tokenizer_type": "huggingface",
+        "tokenizer_path": "deepseek-ai/DeepSeek-V3.2",
+        "hf_access_token": "fake",
+        "scan_layers": False,
+    }
+    B = kwargs["per_device_batch_size"]
+    S = kwargs["max_target_length"]
+    G = kwargs["mhc_expansion_rate"]
+    D = kwargs["base_emb_dim"]
+    K = kwargs["engram_kernel_size"]
+    H = kwargs["engram_num_heads"]
+    H_D = kwargs["engram_head_dim"]
+    L = len(kwargs["engram_layers"])
+    N = kwargs["engram_max_ngram_size"]
+
+    attention_flops = self.compute_deepseek_attention_flops_per_device(kwargs)
+    # deepseek2-16b has ~2.4B active parameters
+    # https://arxiv.org/pdf/2405.04434
+    golden_param_size = 2.4e9
+
+    # calculate Engram active params
+    engram_dim = H * H_D * (N - 1)
+    key_params = engram_dim * (G * D)
+    value_params = engram_dim * D
+    conv_params = K * (G * D)
+    engram_active_params = L * (key_params + value_params + conv_params)
+    golden_tflops = 6 * B * S * (golden_param_size + engram_active_params) / 1e12 + attention_flops
+
+    cfg = pyconfig.initialize(
+        [None, get_test_config_path()],
+        **kwargs,
+    )
+    calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg)
+    self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops)

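The golden value folds the Engram parameters into the standard 6 * B * S * params approximation, where 6 = 2 FLOPs per multiply-accumulate times 3 for the forward plus backward passes, matching the scaling in calculate_engram_tflops. A back-of-envelope check with the test's settings (illustrative arithmetic only):

# Parameter count implied by the test's Engram settings.
B, S, G, D = 4, 8192, 1, 2048
K, H, H_D, L, N = 4, 8, 1280, 2, 3

engram_dim = H * H_D * (N - 1)                   # 20480
key_params = engram_dim * (G * D)                # ~41.9M
value_params = engram_dim * D                    # ~41.9M
conv_params = K * (G * D)                        # 8192
engram_active_params = L * (key_params + value_params + conv_params)  # ~167.8M
print(6 * B * S * engram_active_params / 1e12)   # ~33.0 extra TFLOPs from the Engram layers

This agrees with the per-component breakdown sketched after calculate_engram_tflops above.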