
Commit a2a238f

Integrate Engram into custom model
1 parent f3d9f5c commit a2a238f

9 files changed: 211 additions & 35 deletions


src/maxtext/configs/base.yml

Lines changed: 18 additions & 0 deletions

@@ -315,6 +315,7 @@ out_proj: 'remat'
 mla_q: 'remat'
 mla_kv: 'remat'
 attention_out: 'remat'
+engram: 'remat'

 optimizer_memory_host_offload: False
 parameter_memory_host_offload: False

@@ -1102,3 +1103,20 @@ force_q_layout: false
 mhc_expansion_rate: 1
 # The number of iterations for the Sinkhorn-Knopp algorithm.
 sinkhorn_iterations: 20
+
+################################## DeepSeek Engram ##################################
+# Indices of transformer layers where Engram modules are integrated; leave empty [] to disable.
+# Example: [1, 4] attaches to the 2nd and 5th layers.
+engram_layers: []
+# The maximum n-gram order 'n'. Example: n=3 covers both 2-grams and 3-grams.
+engram_max_ngram_size: 3
+# Number of heads dedicated to the Engram.
+engram_num_heads: 8
+# Head dimension of each Engram head.
+engram_head_dim: 1280
+# List of minimum head vocab sizes for each n-gram order.
+engram_vocab_bases: []
+# Temporal window size for the Engram convolution.
+engram_kernel_size: 4
+# The seed for Engram hash mapping.
+engram_seed: 0

src/maxtext/configs/models/deepseek-custom.yml

Lines changed: 8 additions & 0 deletions

@@ -59,3 +59,11 @@ index_topk: 256 # Reduced from 2048
 # Hyper-connections: mHC enabled
 mhc_expansion_rate: 4
 sinkhorn_iterations: 20
+# Engram
+engram_layers: [1, 4]
+engram_num_heads: 8
+engram_head_dim: 512
+engram_vocab_bases: [226240, 226240]
+engram_max_ngram_size: 3
+engram_kernel_size: 4
+engram_seed: 0
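A note on these values: engram_vocab_bases carries one entry per n-gram order above 1, so with engram_max_ngram_size: 3 it lists a minimum table size for the 2-gram heads and one for the 3-gram heads (the validation added in types.py below enforces this count). Given the `from sympy import isprime` import and the "Pre-calculate unique prime vocab sizes for every head" comment in engram.py, each head presumably gets a distinct prime at or above its base; the helper below is only a hypothetical sketch of that derivation, not the committed code.

from sympy import isprime

def distinct_primes_at_least(base: int, num_heads: int) -> list[int]:
  """Hypothetical sketch: choose `num_heads` distinct primes >= `base`.

  One prime-sized hash table per head keeps the modulo step in the n-gram
  hashing from producing correlated collisions across heads.
  """
  primes, candidate = [], base
  while len(primes) < num_heads:
    if isprime(candidate):
      primes.append(candidate)
    candidate += 1
  return primes

# deepseek-custom.yml: 8 heads per order, bases [226240, 226240] for 2-grams and 3-grams.
print(distinct_primes_at_least(226240, 8))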

src/maxtext/configs/types.py

Lines changed: 31 additions & 0 deletions

@@ -897,6 +897,7 @@ class RematAndOffload(BaseModel):
       RematLocation.REMAT,
       description="Remat policy for the attention output.",
   )
+  engram: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the engram output.")

   optimizer_memory_host_offload: bool = Field(False, description="Offload optimizer state to host memory.")
   parameter_memory_host_offload: bool = Field(False, description="Offload parameters to host memory.")

@@ -1630,6 +1631,23 @@ class SpecialTokens(BaseModel):
   solution_end_token: str = Field("</answer>", description="Token to mark the end of a solution section.")


+class Engram(BaseModel):
+  """Configuration for DeepSeek Engram (https://www.arxiv.org/pdf/2601.07372)."""
+
+  engram_layers: list[int] = Field(
+      default_factory=list,
+      description="Indices of transformer layers where Engram modules are integrated.",
+  )
+  engram_num_heads: int = Field(8, description="Number of heads dedicated to the Engram.")
+  engram_head_dim: int = Field(1280, description="Head dimension of each Engram head.")
+  engram_vocab_bases: list[int] = Field(
+      default_factory=list, description="List of minimum head vocab sizes for each n-gram order."
+  )
+  engram_max_ngram_size: int = Field(3, description="The maximum n-gram order 'n'.")
+  engram_kernel_size: int = Field(4, description="Temporal window size for the Engram convolution.")
+  engram_seed: int = Field(0, description="The seed for Engram hash mapping.")
+
+
 class DerivedValues(BaseModel):
   """Holds all fields that are derived from other config values for perfect legacy compatibility."""

@@ -1782,6 +1800,7 @@ class MaxTextConfig(
     Quantization,
     # Core Model Architecture
     ModelArchitecture,
+    Engram,
     MTP,
     Logits,
     # Attention Mechanisms

@@ -2262,6 +2281,18 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         and self.gradient_accumulation_steps > 1
     ):
       raise ValueError("FP8 quantization is not compatible with gradient accumulation.")
+    if self.engram_layers:
+      if not self.hf_access_token or not self.tokenizer_path:
+        raise ValueError(
+            "Engram requires both 'hf_access_token' and 'tokenizer_path' to load the Hugging Face tokenizer."
+        )
+      if self.scan_layers:
+        raise NotImplementedError("Currently Engram only supports the unscanned version. Please set scan_layers=False.")
+      if len(self.engram_vocab_bases) != (self.engram_max_ngram_size - 1):
+        raise ValueError(
+            f"Engram vocab size mismatch: expected {self.engram_max_ngram_size - 1} (max_ngram_size - 1), "
+            f"but got {self.engram_vocab_bases}."
+        )
     if self.num_experts > 1:
       is_fully_moe = (
           self.interleave_moe_layer_step == 1
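The new cross-field checks tie the Engram options together: Engram needs a reachable Hugging Face tokenizer, currently only works with scan_layers=False, and expects exactly engram_max_ngram_size - 1 entries in engram_vocab_bases (one per n-gram order from 2 to N). A minimal standalone restatement of the shape check, with hypothetical function and argument names rather than the MaxTextConfig API, just to show which settings pass:

def check_engram_vocab_bases(engram_layers, engram_max_ngram_size, engram_vocab_bases):
  """Hypothetical sketch of the vocab-bases check added above; not the real validator."""
  if not engram_layers:
    return  # Engram disabled, nothing to validate
  expected = engram_max_ngram_size - 1  # one base per n-gram order 2..N
  if len(engram_vocab_bases) != expected:
    raise ValueError(f"expected {expected} vocab bases, got {engram_vocab_bases}")

check_engram_vocab_bases([1, 4], 3, [226240, 226240])  # passes (deepseek-custom.yml values)
# check_engram_vocab_bases([1, 4], 3, [226240])        # would raise ValueError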

src/maxtext/layers/decoders.py

Lines changed: 11 additions & 1 deletion

@@ -916,11 +916,19 @@ def __call__(
       num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
       num_layers_list = [cfg.first_num_dense_layers, num_moe_layers]
       # Iterate over the two layer groups (dense and MoE) and apply layer transformation
+      global_layer_idx_offset = 0
       for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
         for index in range(num_layers):
+          global_layer_idx = global_layer_idx_offset + index
           kv_cache = kv_caches[index] if kv_caches is not None else None
+          input_tokens = decoder_input_tokens if cfg.engram_layers else None
           y, kv_cache = layer(
-              config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
+              config=cfg,
+              mesh=mesh,
+              name=f"{layer_prefix}_{index}",
+              quant=self.quant,
+              model_mode=self.model_mode,
+              layer_idx=global_layer_idx,
           )(
               y,
               decoder_segment_ids,

@@ -932,9 +940,11 @@ def __call__(
               slot=slot,
               kv_cache=kv_cache,
               attention_metadata=attention_metadata,
+              decoder_input_tokens=input_tokens,
           )
           if kv_caches is not None and kv_cache is not None:
             kv_caches[index] = kv_cache
+        global_layer_idx_offset += num_layers
     else:
       for lyr in range(cfg.num_decoder_layers):
         RemattedBlockLayer = RemattedBlockLayers[0]
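In this branch the decoder is built as two unscanned groups: first_num_dense_layers dense layers followed by the MoE layers. The new global_layer_idx gives every layer its absolute position across both groups, so that each layer can presumably check whether its index appears in the config's engram_layers list, which counts from the start of the model. A toy sketch of that indexing with made-up layer counts:

# Toy sketch of the global layer indexing added above; layer counts are made up.
first_num_dense_layers, num_moe_layers = 3, 5
num_layers_list = [first_num_dense_layers, num_moe_layers]
engram_layers = [1, 4]  # absolute indices, as in deepseek-custom.yml

global_layer_idx_offset = 0
for group, num_layers in zip(["dense", "moe"], num_layers_list):
  for index in range(num_layers):
    global_layer_idx = global_layer_idx_offset + index
    if global_layer_idx in engram_layers:
      print(f"{group} layer {index} is global layer {global_layer_idx}: Engram attached")
  global_layer_idx_offset += num_layers
# Prints: dense layer 1 is global layer 1, moe layer 1 is global layer 4.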

src/maxtext/layers/engram.py

Lines changed: 36 additions & 24 deletions

@@ -35,6 +35,7 @@
 from sympy import isprime
 from tokenizers import Regex, normalizers

+
 class CompressedTokenizer:
   """
   A canonicalizing wrapper that reduces vocabulary sparsity for n-gram lookup.

@@ -50,7 +51,8 @@ class CompressedTokenizer:

   def __init__(self, tokenizer: HFTokenizer):
     normalizer = self._build_normalizer()
-    self.lookup_table, self.num_new_token = self._build_lookup_table(tokenizer, normalizer)
+    self.lookup_table_np, self.num_new_token = self._build_lookup_table(tokenizer, normalizer)
+    self.lookup_table = jnp.array(self.lookup_table_np, dtype=jnp.int64)

   def __len__(self) -> int:
     return self.num_new_token

@@ -118,19 +120,18 @@ def _build_lookup_table(self, tokenizer: HFTokenizer, normalizer: normalizers.Se

     return old2new, len(key2new)

-  def __call__(self, input_ids) -> np.ndarray:
+  def __call__(self, input_ids) -> Array:
     """
     Maps original token IDs to compressed IDs.
     """
-    input_ids = np.asarray(input_ids, dtype=np.int64)
+    input_ids = jnp.asarray(input_ids, dtype=jnp.int64)

-    # Identify valid tokens (ignore padding/masks usually marked with negative IDs)
-    valid_mask = input_ids >= 0
-    valid_ids = input_ids[valid_mask]
+    # Map negative IDs to 0 for lookup, then mask output back.
+    safe_ids = jnp.where(input_ids < 0, 0, input_ids)
+    mapped_ids = self.lookup_table[safe_ids]

-    # Vectorized lookup: O(1) per token
-    output_ids = input_ids.copy()
-    output_ids[valid_mask] = self.lookup_table[valid_ids]
+    # Restore negative IDs (padding)
+    output_ids = jnp.where(input_ids < 0, input_ids, mapped_ids)
     return output_ids
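The rewritten __call__ swaps numpy's boolean-mask assignment, which JAX arrays do not support, for a branch-free jnp.where pattern: clamp negative IDs to a safe index, gather from the table, then put the negatives back. A self-contained toy version of that pattern (the table and IDs below are made up):

import jax.numpy as jnp

lookup_table = jnp.array([10, 11, 12, 13])             # toy old->new id map
input_ids = jnp.array([[2, 0, -1], [3, -1, -1]])        # -1 marks padding / masked slots

safe_ids = jnp.where(input_ids < 0, 0, input_ids)       # clamp negatives to a valid index
mapped = lookup_table[safe_ids]                          # vectorized gather
output = jnp.where(input_ids < 0, input_ids, mapped)     # restore the negative (padding) IDs
print(output)  # [[12 10 -1] [13 -1 -1]]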

@@ -177,11 +178,16 @@ def __init__(
     # Initialize compressed tokenizer
     self.compressed_tokenizer = CompressedTokenizer(tokenizer)
     self.tokenizer_vocab_size = len(self.compressed_tokenizer)
-    if pad_id is not None:
-      self.pad_id = int(self.compressed_tokenizer.lookup_table[pad_id])
+    if pad_id is None:
+      raise ValueError("The `pad_id` must be provided and cannot be None.")
+    # Pre-calculate pad_id on CPU using numpy array to avoid ConcretizationTypeError
+    self.pad_id = int(self.compressed_tokenizer.lookup_table_np[pad_id])

     # Pre-calculate odd multipliers for hashing: {layer_id: multipliers}
-    self.layer_multipliers = self._calculate_multipliers_across_layers(seed)
+    # Store as JAX arrays
+    self.layer_multipliers = {
+        k: jnp.array(v, dtype=jnp.int64) for k, v in self._calculate_multipliers_across_layers(seed).items()
+    }

     # Pre-calculate unique prime vocab sizes for every head
     # Structure: {layer_id: [[2gram_head1, ..., 2gram_headH], ..., [Ngram_head1, ..., Ngram_headH]]}

@@ -254,7 +260,7 @@ def get_vocab_sizes(self, layer_id: int) -> List[int]:
     """
     return [head_size for ngram_size in self.vocab_size_across_layers[layer_id] for head_size in ngram_size]

-  def _get_ngram_hashes(self, compressed_ids: np.ndarray, layer_id: int) -> np.ndarray:
+  def _get_ngram_hashes(self, compressed_ids: Array, layer_id: int) -> Array:
     """
     Computes hash indices for all n-grams in the input batch.
@@ -265,22 +271,21 @@ def _get_ngram_hashes(self, compressed_ids: np.ndarray, layer_id: int) -> np.nda
     Returns:
       hash_ids: [B, S, H_total] where H_total = H * num_ngram_orders
     """
-    x = np.asarray(compressed_ids, dtype=np.int64)
-    B, S = x.shape
+    x = jnp.asarray(compressed_ids, dtype=jnp.int64)
+    B, _ = x.shape

     # 1. Create Sliding Windows via Shifting
     shifted_inputs = []
     for k in range(self.max_ngram_size):
       if k == 0:
-        # e.g., k=0, [The, cat, sat]
         shifted_inputs.append(x)
       else:
         # Pre-allocate full array with PAD_ID
-        shifted_x = np.full((B, S), self.pad_id, dtype=np.int64)
+        padding = jnp.full((B, k), self.pad_id, dtype=jnp.int64)
         # Fast memory copy, slicing and assignment
         # e.g., k=1, [PAD, The, cat]
         #       k=2, [PAD, PAD, The]
-        shifted_x[:, k:] = x[:, :-k]
+        shifted_x = jnp.concatenate([padding, x[:, :-k]], axis=1)
         shifted_inputs.append(shifted_x)

     # 2. Retrieve layer-specific hash multipliers

@@ -299,21 +304,21 @@

     for n in range(2, self.max_ngram_size + 1):
       # Update hash with next history token
-      ngram_hash = np.bitwise_xor(ngram_hash, shifted_inputs[n - 1] * multipliers[n - 1])
+      ngram_hash = jnp.bitwise_xor(ngram_hash, shifted_inputs[n - 1] * multipliers[n - 1])

       # Retrieve prime vocab sizes for all heads of this n-gram order
       vocab_sizes_for_this_gram = vocab_sizes[n - 2]
-      mods = np.array(vocab_sizes_for_this_gram, dtype=np.int64)
+      mods = jnp.array(vocab_sizes_for_this_gram, dtype=jnp.int64)

       # Broadcast Modulo: Map hash to valid table indices
       # [B, S, 1] % [H] -> [B, S, H]
       head_hashes = ngram_hash[..., None] % mods
       all_hashes.append(head_hashes)

     # Concatenate all heads: [B, S, H_total] where H_total = H * num_ngram_orders
-    return np.concatenate(all_hashes, axis=2)
+    return jnp.concatenate(all_hashes, axis=2)

-  def __call__(self, input_ids) -> dict[int, np.ndarray]:
+  def __call__(self, input_ids) -> dict[int, Array]:
     # input_ids from standard tokenizer
     compressed_ids = self.compressed_tokenizer(input_ids)
     hash_ids_for_all_layers = {}
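The two hunks above port the n-gram hashing to JAX: one shifted copy of the sequence per history position, per-position odd multipliers combined with XOR, and a modulo against each head's prime table size. Below is a toy, self-contained walk-through of that pipeline for 2-grams and 3-grams; the multipliers, the primes, and the initialization of ngram_hash from the unshifted tokens are assumptions for illustration (the real values come from engram_seed and from code not shown in this diff):

import jax.numpy as jnp

PAD_ID = 0
x = jnp.array([[5, 9, 2, 7]])                   # compressed token IDs, shape [B=1, S=4]
multipliers = jnp.array([3, 7, 11])             # one odd multiplier per history position (toy values)
prime_vocab_sizes = {2: jnp.array([101, 103]),  # per-head table sizes for 2-grams (toy primes)
                     3: jnp.array([107, 109])}  # and for 3-grams

# Shifted copies: position k holds the token k steps back (PAD where no history exists).
shifted = [x]
for k in range(1, 3):
  pad = jnp.full((x.shape[0], k), PAD_ID)
  shifted.append(jnp.concatenate([pad, x[:, :-k]], axis=1))

ngram_hash = x * multipliers[0]                 # assumed initialization from the unshifted tokens
all_hashes = []
for n in (2, 3):
  ngram_hash = jnp.bitwise_xor(ngram_hash, shifted[n - 1] * multipliers[n - 1])
  all_hashes.append(ngram_hash[..., None] % prime_vocab_sizes[n])   # [B, S, H]
hash_ids = jnp.concatenate(all_hashes, axis=2)                      # [B, S, H_total]
print(hash_ids.shape)  # (1, 4, 4)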
@@ -323,6 +328,13 @@ def __call__(self, input_ids) -> dict[int, np.ndarray]:
     return hash_ids_for_all_layers


+class StaticWrapper:
+  """Wrapper to prevent nnx from treating the value as a variable."""
+
+  def __init__(self, val):
+    self.val = val
+
+
 class MultiHeadEmbedding(nnx.Module):
   """
   A flattened table representation for multi-head embedding spaces across n-gram orders.

@@ -350,7 +362,7 @@ def __init__(
     # Compute starting index for each head's segment in the flattened table.
     # Offsets serve as the "base address" for each head.
     offsets = np.cumsum([0] + vocab_sizes[:-1])  # prefix sum
-    self.offsets = jnp.array(offsets, dtype=jnp.int32)
+    self.offsets = StaticWrapper(np.array(offsets, dtype=np.int64))

     # The total embedding size is the sum of all individual head vocabularies.
     self.embedding = Embed(num_embeddings=sum(vocab_sizes), num_features=head_dim, config=config, mesh=mesh, rngs=rngs)

@@ -368,7 +380,7 @@ def __call__(self, input_ids: Array, model_mode: str = MODEL_MODE_TRAIN) -> Arra
     """
     # Broadcasting Add: [B, S, H] + [H] -> [B, S, H]
     # Shifts local indices (0..Prime-1) to global table positions.
-    shifted_ids = input_ids + self.offsets
+    shifted_ids = input_ids + self.offsets.val

     # Embedding lookup: [B, S, H_total] -> [B, S, H_total, D_head]
     return self.embedding(shifted_ids, model_mode=model_mode)
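MultiHeadEmbedding keeps a single flattened Embed table for all heads; the prefix-sum offsets relocate each head's local hash index into its own segment, and wrapping the constant offsets in StaticWrapper keeps nnx from registering them as model state. A toy version of just the offset arithmetic (head sizes are made up):

import numpy as np

vocab_sizes = [101, 103, 107, 109]            # per-head table sizes (toy primes)
offsets = np.cumsum([0] + vocab_sizes[:-1])   # base address of each head's segment
print(offsets)                                # [  0 101 204 311]

# A local index of 5 in head 2 lands at row 204 + 5 = 209 of the flattened table.
local_ids = np.array([[3, 7, 5, 2]])          # per-head local indices for one token position, shape [1, H]
global_ids = local_ids + offsets
print(global_ids)                             # [[  3 108 209 313]]

The switch from a jnp array to a host-side numpy constant matches the StaticWrapper docstring: the offsets are fixed addresses derived from the config, not parameters or traced values.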
