
Commit d14f70d

Merge pull request #3255 from AI-Hypercomputer:engram_scan_clean
PiperOrigin-RevId: 879272037
2 parents 13e6f40 + 88afe99 commit d14f70d

4 files changed

Lines changed: 315 additions & 31 deletions


src/maxtext/configs/types.py

Lines changed: 0 additions & 2 deletions
@@ -2348,8 +2348,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       raise ValueError(
           "Engram requires both 'hf_access_token' and 'tokenizer_path' " "to load the Hugging Face tokenizer."
       )
-    if self.scan_layers:
-      raise NotImplementedError("Currently Engram only supports unscanned version. Please set scan_layers=False.")
     if len(self.engram_vocab_bases) != (self.engram_max_ngram_size - 1):
       raise ValueError(
           f"Engram vocab size mismatch: expected {self.engram_max_ngram_size - 1} (max_ngram_size - 1), "

src/maxtext/layers/decoders.py

Lines changed: 139 additions & 28 deletions
@@ -854,45 +854,85 @@ def __call__(
           "slot": slot,
       }
       dense_layer = RemattedBlockLayers[0]
-      dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
-      y, _ = self.scan_decoder_layers(
-          cfg,
-          dense_layer,
-          cfg.first_num_dense_layers,
-          "dense_layers",
-          mesh,
-          in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-          model_mode=model_mode,
-      )(y, *broadcast_args)
       moe_layer = RemattedBlockLayers[1]
-      moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
-      num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
+      if cfg.engram_layers:
+        original_dense_call = dense_layer.__call__
+        original_moe_call = moe_layer.__call__
+        dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
+        moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
+
+        common_kwargs = {
+            "dense_layer": dense_layer,
+            "moe_layer": moe_layer,
+            "original_dense_call": original_dense_call,
+            "original_moe_call": original_moe_call,
+            "layer_call_kwargs": layer_call_kwargs,
+            "decoder_segment_ids": decoder_segment_ids,
+            "decoder_positions": decoder_positions,
+            "deterministic": deterministic,
+            "model_mode": model_mode,
+            "decoder_input_tokens": decoder_input_tokens,
+            "broadcast_args": broadcast_args,
+        }
 
-      # If batch-split schedule is used and initialization is complete,
-      # as detected by immutable params, use deepseek_batchsplit custom
-      # scan with initialized parameters.
-      if cfg.use_batch_split_schedule and not self.is_mutable_collection("params"):
-        y = deepseek_batchsplit.scan_batch_split_layers(
+        # Apply Dense Layers
+        y = self._apply_interleaved_scanned_layers(
             y,
-            self.variables["params"]["moe_layers"],
-            decoder_positions,
-            decoder_segment_ids,
-            model_mode=model_mode,
-            mesh=mesh,
-            quant=self.quant,
-            cfg=cfg,
-            policy=policy,
+            layer_type="dense",
+            start_idx=0,
+            end_idx=cfg.first_num_dense_layers,
+            engram_indices=cfg.engram_layers,
+            **common_kwargs,
+        )
+
+        # Apply MoE Layers
+        y = self._apply_interleaved_scanned_layers(
+            y,
+            layer_type="moe",
+            start_idx=cfg.first_num_dense_layers,
+            end_idx=cfg.num_decoder_layers,
+            engram_indices=cfg.engram_layers,
+            **common_kwargs,
         )
       else:
+        dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
         y, _ = self.scan_decoder_layers(
             cfg,
-            moe_layer,
-            num_moe_layers,
-            "moe_layers",
+            dense_layer,
+            cfg.first_num_dense_layers,
+            "dense_layers",
             mesh,
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
             model_mode=model_mode,
         )(y, *broadcast_args)
+        moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
+        num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
+
+        # If batch-split schedule is used and initialization is complete,
+        # as detected by immutable params, use deepseek_batchsplit custom
+        # scan with initialized parameters.
+        if cfg.use_batch_split_schedule and not self.is_mutable_collection("params"):
+          y = deepseek_batchsplit.scan_batch_split_layers(
+              y,
+              self.variables["params"]["moe_layers"],
+              decoder_positions,
+              decoder_segment_ids,
+              model_mode=model_mode,
+              mesh=mesh,
+              quant=self.quant,
+              cfg=cfg,
+              policy=policy,
+          )
+        else:
+          y, _ = self.scan_decoder_layers(
+              cfg,
+              moe_layer,
+              num_moe_layers,
+              "moe_layers",
+              mesh,
+              in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
+              model_mode=model_mode,
+          )(y, *broadcast_args)
     elif cfg.decoder_block == DecoderBlockType.GEMMA3:
       y = self._apply_gemma3_scanned_blocks(
           y,
@@ -1118,3 +1158,74 @@ def _apply_gemma3_scanned_blocks(
           **layer_call_kwargs,
       )
     return y
+
+  # TODO(b/490118813): Relocate the following functions to their designated directories
+  # once the plug-in strategy is implemented: _find_next_boundary(), _apply_single_engram_layer()
+  # _apply_scanned_chunk() and _apply_interleaved_scanned_layers().
+  def _find_next_boundary(self, current_idx, end_idx, engram_indices):
+    """Finds the next index boundary, either the next Engram layer index or the overall end index."""
+    next_engrams = [l for l in engram_indices if l > current_idx]
+    if next_engrams:
+      return min(end_idx, *next_engrams)
+    return end_idx
+
+  def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
+    """Applies a single, unscanned Engram layer."""
+    layer = kwargs["dense_layer"] if layer_type == "dense" else kwargs["moe_layer"]
+    layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
+    original_call = kwargs["original_dense_call"] if layer_type == "dense" else kwargs["original_moe_call"]
+    layer_call_kwargs = kwargs["layer_call_kwargs"]
+
+    layer.__call__ = original_call
+    y, _ = layer(
+        config=self.config,
+        mesh=self.mesh,
+        name=f"{layer_prefix}_engram_{current_idx}",
+        quant=self.quant,
+        model_mode=self.model_mode,
+        layer_idx=current_idx,
+    )(
+        y,
+        kwargs["decoder_segment_ids"],
+        kwargs["decoder_positions"],
+        kwargs["deterministic"],
+        kwargs["model_mode"],
+        decoder_input_tokens=kwargs["decoder_input_tokens"],
+        **layer_call_kwargs,
+    )
+    layer.__call__ = functools.partial(original_call, **layer_call_kwargs)
+    return y
+
+  def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_type, **kwargs):
+    """Applies a contiguous chunk of layers using the scan operation."""
+    layer = kwargs["dense_layer"] if layer_type == "dense" else kwargs["moe_layer"]
+    layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
+    broadcast_args = kwargs["broadcast_args"]
+    scan_length = next_boundary - current_idx
+
+    if scan_length > 0:
+      y, _ = self.scan_decoder_layers(
+          self.config,
+          layer,
+          scan_length,
+          f"{layer_prefix}_{current_idx}_{next_boundary - 1}",
+          self.mesh,
+          in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
+          model_mode=kwargs["model_mode"],
+      )(y, *broadcast_args)
+    return y
+
+  def _apply_interleaved_scanned_layers(self, y, layer_type, start_idx, end_idx, engram_indices, **kwargs):
+    """Applies a mix of scanned standard layers and unscanned Engram layers."""
+    current_idx = start_idx
+    while current_idx < end_idx:
+      if current_idx in engram_indices:
+        # Handle individual unscanned Engram layer
+        y = self._apply_single_engram_layer(y, current_idx, layer_type, **kwargs)
+        current_idx += 1
+      else:
+        # Find next boundary and scan the chunk
+        next_boundary = self._find_next_boundary(current_idx, end_idx, engram_indices)
+        y = self._apply_scanned_chunk(y, current_idx, next_boundary, layer_type, **kwargs)
+        current_idx = next_boundary
+    return y
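
Taken together, these helpers walk each block of layers front to back: every index listed in cfg.engram_layers becomes a standalone, unscanned Engram layer, and every maximal run of ordinary layers between two Engram indices is folded into a single scan_decoder_layers call. A minimal standalone sketch of that scheduling (plan_layer_schedule is a hypothetical name, not part of this diff) that reproduces the module names asserted by the new tests below:

    # Hypothetical sketch of the interleaving schedule implemented by
    # _find_next_boundary() and _apply_interleaved_scanned_layers() above.
    def plan_layer_schedule(start_idx, end_idx, engram_indices, prefix):
      """Lists the module names created for layers in [start_idx, end_idx)."""
      names, current = [], start_idx
      while current < end_idx:
        if current in engram_indices:
          names.append(f"{prefix}_engram_{current}")  # single unscanned layer
          current += 1
        else:
          later = [l for l in engram_indices if l > current]
          boundary = min(end_idx, *later) if later else end_idx
          names.append(f"{prefix}_{current}_{boundary - 1}")  # scanned chunk
          current = boundary
      return names

    # With 5 dense + 5 MoE layers and engram_layers=[2, 8], as in the tests:
    # plan_layer_schedule(0, 5, [2, 8], "dense_layers")
    #   -> ['dense_layers_0_1', 'dense_layers_engram_2', 'dense_layers_3_4']
    # plan_layer_schedule(5, 10, [2, 8], "moe_layers")
    #   -> ['moe_layers_5_7', 'moe_layers_engram_8', 'moe_layers_9_9']

Note also how _apply_single_engram_layer() temporarily restores the original __call__ before invoking the layer directly, then re-binds the functools.partial with layer_call_kwargs so that subsequent scanned chunks see the same bound signature.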
New file: unit tests for DeepSeek Engram across scanned decoder layers (file path not shown in this view)

Lines changed: 175 additions & 0 deletions

@@ -0,0 +1,175 @@
# Copyright 2023-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for DeepSeek Engram across scanned decoder layers."""

import gc
import os
import unittest
from unittest.mock import patch

import jax
import jax.numpy as jnp
from jax.sharding import Mesh

from maxtext.configs import pyconfig
from maxtext.utils.globals import MAXTEXT_PKG_DIR
from maxtext.common.common_types import MODEL_MODE_TRAIN
from maxtext.layers.decoders import Decoder
from maxtext.utils import maxtext_utils
import pytest


class DummyEmbedding:
  """Dummy embedding layer for testing."""

  def __init__(self, emb_dim: int):
    self.emb_dim = emb_dim

  def __call__(self, x, model_mode):
    return jnp.ones((x.shape[0], x.shape[1], self.emb_dim))


class TestDeepSeekScanEngram(unittest.TestCase):
  """Test DeepSeek decoder block with scan_layers=True and engram_layers."""

  _COMMON_CONFIG = [
      "run_name=test_deepseek_scan_engram",
      "model_name=deepseek-custom",
      "override_model_config=True",
      "decoder_block=deepseek",
      "scan_layers=True",
      "first_num_dense_layers=5",
      "base_num_decoder_layers=10",
      "num_decoder_layers=10",
      "mhc_expansion_rate=4",
      "attention=dot_product",
      "per_device_batch_size=2",
      "max_target_length=8",
      "max_prefill_predict_length=8",
      "enable_checkpointing=False",
      "engram_num_heads=1",
      "engram_head_dim=32",
      "engram_vocab_bases=[226240,226240]",
      "engram_max_ngram_size=3",
      "engram_kernel_size=4",
      "hf_access_token=dummy",
      "tokenizer_path=dummy",
  ]

  def _test_engram_pattern(self, mock_from_pretrained, engram_layers_str, expected_keys):
    """Helper method to test different engram layer patterns."""

    # Setup mock tokenizer
    class MockTokenizer:
      """Mock tokenizer for testing."""

      pad_token_id = 0

      def __len__(self):
        return 1000

      def __call__(self, x):
        return jnp.ones_like(x)

      def convert_ids_to_tokens(self, *args, **kwargs):
        return "a"

      def decode(self, *args, **kwargs):
        return "a"

      def batch_decode(self, token_ids, *args, **kwargs):
        return ["a" for _ in token_ids]

    mock_from_pretrained.return_value = MockTokenizer()

    config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")
    config = pyconfig.initialize([None, config_path] + self._COMMON_CONFIG + [f"engram_layers=[{engram_layers_str}]"])

    devices_array = maxtext_utils.create_device_mesh(config)
    mesh = Mesh(devices_array, config.mesh_axes)

    decoder = Decoder(
        config=config,
        mesh=mesh,
        model_mode=MODEL_MODE_TRAIN,
    )

    batch_size = config.global_batch_size_to_load
    seq_len = config.max_target_length

    decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
    decoder_positions = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
    decoder_segment_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)

    shared_embedding = DummyEmbedding(emb_dim=config.emb_dim)

    with mesh:
      variables = decoder.init(
          {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1), "aqt": jax.random.PRNGKey(2)},
          shared_embedding=shared_embedding,
          decoder_input_tokens=decoder_input_tokens,
          decoder_positions=decoder_positions,
          decoder_segment_ids=decoder_segment_ids,
          deterministic=True,
          model_mode=MODEL_MODE_TRAIN,
      )

    self.assertIn("params", variables)
    params = variables["params"]
    for key in expected_keys:
      self.assertIn(key, params)

    del variables
    del params
    del decoder
    jax.clear_caches()
    gc.collect()

  @pytest.mark.tpu_only
  @patch("transformers.AutoTokenizer.from_pretrained")
  def test_decoder_init_engram_2_8(self, mock_from_pretrained):
    """Test engram layers at indices 2 and 8."""
    self._test_engram_pattern(
        mock_from_pretrained,
        "2,8",
        [
            "dense_layers_0_1",
            "dense_layers_engram_2",
            "dense_layers_3_4",
            "moe_layers_5_7",
            "moe_layers_engram_8",
            "moe_layers_9_9",
        ],
    )

  @pytest.mark.tpu_only
  @patch("transformers.AutoTokenizer.from_pretrained")
  def test_decoder_init_engram_0_5(self, mock_from_pretrained):
    """Test engram layers at indices 0 and 5 - first engram layer of block."""
    self._test_engram_pattern(
        mock_from_pretrained,
        "0,5",
        ["dense_layers_engram_0", "dense_layers_1_4", "moe_layers_engram_5", "moe_layers_6_9"],
    )

  @pytest.mark.tpu_only
  @patch("transformers.AutoTokenizer.from_pretrained")
  def test_decoder_init_engram_4_9(self, mock_from_pretrained):
    """Test engram layers at indices 4 and 9 - last engram layer of block."""
    self._test_engram_pattern(
        mock_from_pretrained,
        "4,9",
        ["dense_layers_0_3", "dense_layers_engram_4", "moe_layers_5_8", "moe_layers_engram_9"],
    )

tests/unit/train_compile_test.py

Lines changed: 1 addition & 1 deletion
@@ -830,7 +830,7 @@ def test_engram_integration(self):
         "compile_topology_num_slices=1",
         "model_name=deepseek-custom",
         "per_device_batch_size=4",
-        "scan_layers=False",  # TODO(ranran): update to scan_layers=True after support
+        "scan_layers=True",
         "max_target_length=1024",
         "attention=flash",
         "use_tokamax_splash=True",
