Skip to content

Commit bf174d6

Browse files
Merge pull request #3183 from AI-Hypercomputer:new_engram_integration
PiperOrigin-RevId: 875313089
2 parents 754a4d2 + a2a238f commit bf174d6

9 files changed

Lines changed: 217 additions & 40 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ out_proj: 'remat'
315315
mla_q: 'remat'
316316
mla_kv: 'remat'
317317
attention_out: 'remat'
318+
engram: 'remat'
318319

319320
optimizer_memory_host_offload: False
320321
parameter_memory_host_offload: False
@@ -1102,3 +1103,20 @@ force_q_layout: false
11021103
mhc_expansion_rate: 1
11031104
# The number of iterations for the Sinkhorn-Knopp algorithm.
11041105
sinkhorn_iterations: 20
1106+
1107+
################################## DeepSeek Engram ##################################
1108+
# Indices of transformer layers where Engram modules are integrated; leave empty [] to disable.
1109+
# Example: [1, 4] attaches to the 2nd and 5th layers.
1110+
engram_layers: []
1111+
# The max 'n' in N-gram. Example: n=3 means it covers both 2-grams and 3-grams.
1112+
engram_max_ngram_size: 3
1113+
# Number of heads dedicated to the Engram.
1114+
engram_num_heads: 8
1115+
# Head dimension for each Engram head.
1116+
engram_head_dim: 1280
1117+
# List of minimum head vocab sizes for each n-gram order.
1118+
engram_vocab_bases: []
1119+
# Temporal window size for Engram convolution.
1120+
engram_kernel_size: 4
1121+
# The seed for Engram hash mapping.
1122+
engram_seed: 0

src/maxtext/configs/models/deepseek-custom.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,11 @@ index_topk: 256 # Reduced from 2048
5959
# Hyper-connections: mHC enabled
6060
mhc_expansion_rate: 4
6161
sinkhorn_iterations: 20
62+
# Engram
63+
engram_layers: [1, 4]
64+
engram_num_heads: 8
65+
engram_head_dim: 512
66+
engram_vocab_bases: [226240, 226240]
67+
engram_max_ngram_size: 3
68+
engram_kernel_size: 4
69+
engram_seed: 0

src/maxtext/configs/types.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,7 @@ class RematAndOffload(BaseModel):
897897
RematLocation.REMAT,
898898
description="Remat policy for the attention output.",
899899
)
900+
engram: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the engram output.")
900901

901902
optimizer_memory_host_offload: bool = Field(False, description="Offload optimizer state to host memory.")
902903
parameter_memory_host_offload: bool = Field(False, description="Offload parameters to host memory.")
@@ -1630,6 +1631,23 @@ class SpecialTokens(BaseModel):
16301631
solution_end_token: str = Field("</answer>", description="Token to mark the end of a solution section.")
16311632

16321633

1634+
class Engram(BaseModel):
1635+
"""Configuration for DeepSeek Engram (https://www.arxiv.org/pdf/2601.07372)."""
1636+
1637+
engram_layers: list[int] = Field(
1638+
default_factory=list,
1639+
description="Indices of transformer layers where Engram are integrated.",
1640+
)
1641+
engram_num_heads: int = Field(8, description="Number of heads dedicated to the Engram.")
1642+
engram_head_dim: int = Field(1280, description="Head dimension for heads.")
1643+
engram_vocab_bases: list[int] = Field(
1644+
default_factory=list, description="List of minimum head vocab sizes for each n-gram order."
1645+
)
1646+
engram_max_ngram_size: int = Field(3, description="The max 'n' in N-gram.")
1647+
engram_kernel_size: int = Field(4, description="Temporal window size for Engram convolution.")
1648+
engram_seed: int = Field(0, description="The seed for Engram hash mapping.")
1649+
1650+
16331651
class DerivedValues(BaseModel):
16341652
"""Holds all fields that are derived from other config values for perfect legacy compatibility."""
16351653

@@ -1782,6 +1800,7 @@ class MaxTextConfig(
17821800
Quantization,
17831801
# Core Model Architecture
17841802
ModelArchitecture,
1803+
Engram,
17851804
MTP,
17861805
Logits,
17871806
# Attention Mechanisms
@@ -2262,6 +2281,18 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
22622281
and self.gradient_accumulation_steps > 1
22632282
):
22642283
raise ValueError("FP8 quantization is not compatible with gradient accumulation.")
2284+
if self.engram_layers:
2285+
if not self.hf_access_token or not self.tokenizer_path:
2286+
raise ValueError(
2287+
"Engram requires both 'hf_access_token' and 'tokenizer_path' " "to load the Hugging Face tokenizer."
2288+
)
2289+
if self.scan_layers:
2290+
raise NotImplementedError("Currently Engram only supports unscanned version. Please set scan_layers=False.")
2291+
if len(self.engram_vocab_bases) != (self.engram_max_ngram_size - 1):
2292+
raise ValueError(
2293+
f"Engram vocab size mismatch: expected {self.engram_max_ngram_size - 1} (max_ngram_size - 1), "
2294+
f"but got {self.engram_vocab_bases}."
2295+
)
22652296
if self.num_experts > 1:
22662297
is_fully_moe = (
22672298
self.interleave_moe_layer_step == 1

src/maxtext/layers/decoders.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,11 +934,19 @@ def __call__(
934934
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
935935
num_layers_list = [cfg.first_num_dense_layers, num_moe_layers]
936936
# Iterate over the two layer groups (dense and MoE) and apply layer transformation
937+
global_layer_idx_offset = 0
937938
for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
938939
for index in range(num_layers):
940+
global_layer_idx = global_layer_idx_offset + index
939941
kv_cache = kv_caches[index] if kv_caches is not None else None
942+
input_tokens = decoder_input_tokens if cfg.engram_layers else None
940943
y, kv_cache = layer(
941-
config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
944+
config=cfg,
945+
mesh=mesh,
946+
name=f"{layer_prefix}_{index}",
947+
quant=self.quant,
948+
model_mode=self.model_mode,
949+
layer_idx=global_layer_idx,
942950
)(
943951
y,
944952
decoder_segment_ids,
@@ -950,9 +958,11 @@ def __call__(
950958
slot=slot,
951959
kv_cache=kv_cache,
952960
attention_metadata=attention_metadata,
961+
decoder_input_tokens=input_tokens,
953962
)
954963
if kv_caches is not None and kv_cache is not None:
955964
kv_caches[index] = kv_cache
965+
global_layer_idx_offset += num_layers
956966
else:
957967
for lyr in range(cfg.num_decoder_layers):
958968
RemattedBlockLayer = RemattedBlockLayers[0]

src/maxtext/layers/engram.py

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@
3232
from maxtext.layers.normalizations import RMSNorm
3333
from maxtext.layers.quantizations import AqtQuantization as Quant
3434
import numpy as np
35-
from sympy import isprime
36-
from tokenizers import Regex, normalizers
35+
import sympy
36+
import tokenizers
37+
from tokenizers import normalizers
38+
3739

3840
class CompressedTokenizer:
3941
"""
@@ -50,7 +52,8 @@ class CompressedTokenizer:
5052

5153
def __init__(self, tokenizer: HFTokenizer):
5254
normalizer = self._build_normalizer()
53-
self.lookup_table, self.num_new_token = self._build_lookup_table(tokenizer, normalizer)
55+
self.lookup_table_np, self.num_new_token = self._build_lookup_table(tokenizer, normalizer)
56+
self.lookup_table = jnp.array(self.lookup_table_np, dtype=jnp.int64)
5457

5558
def __len__(self) -> int:
5659
return self.num_new_token
@@ -74,9 +77,9 @@ def _build_normalizer(self) -> normalizers.Sequence:
7477
# Lowercase conversion ("The" -> "the")
7578
normalizers.Lowercase(),
7679
# Collapse all whitespace variations to a single space
77-
normalizers.Replace(Regex(r"[ \t\r\n]+"), " "),
80+
normalizers.Replace(tokenizers.Regex(r"[ \t\r\n]+"), " "),
7881
# Protect standalone spaces from subsequent stripping
79-
normalizers.Replace(Regex(r"^ $"), SENTINEL),
82+
normalizers.Replace(tokenizers.Regex(r"^ $"), SENTINEL),
8083
# Remove leading/trailing whitespace
8184
normalizers.Strip(),
8285
# Restore protected spaces
@@ -118,19 +121,18 @@ def _build_lookup_table(self, tokenizer: HFTokenizer, normalizer: normalizers.Se
118121

119122
return old2new, len(key2new)
120123

121-
def __call__(self, input_ids) -> np.ndarray:
124+
def __call__(self, input_ids) -> Array:
122125
"""
123126
Maps original token IDs to compressed IDs.
124127
"""
125-
input_ids = np.asarray(input_ids, dtype=np.int64)
128+
input_ids = jnp.asarray(input_ids, dtype=jnp.int64)
126129

127-
# Identify valid tokens (ignore padding/masks usually marked with negative IDs)
128-
valid_mask = input_ids >= 0
129-
valid_ids = input_ids[valid_mask]
130+
# Map negative IDs to 0 for lookup, then mask output back.
131+
safe_ids = jnp.where(input_ids < 0, 0, input_ids)
132+
mapped_ids = self.lookup_table[safe_ids]
130133

131-
# Vectorized lookup: O(1) per token
132-
output_ids = input_ids.copy()
133-
output_ids[valid_mask] = self.lookup_table[valid_ids]
134+
# Restore negative IDs (padding)
135+
output_ids = jnp.where(input_ids < 0, input_ids, mapped_ids)
134136
return output_ids
135137

136138

@@ -177,11 +179,16 @@ def __init__(
177179
# Initialize compressed tokenizer
178180
self.compressed_tokenizer = CompressedTokenizer(tokenizer)
179181
self.tokenizer_vocab_size = len(self.compressed_tokenizer)
180-
if pad_id is not None:
181-
self.pad_id = int(self.compressed_tokenizer.lookup_table[pad_id])
182+
if pad_id is None:
183+
raise ValueError("The `pad_id` must be provided and cannot be None.")
184+
# Pre-calculate pad_id on CPU using numpy array to avoid ConcretizationTypeError
185+
self.pad_id = int(self.compressed_tokenizer.lookup_table_np[pad_id])
182186

183187
# Pre-calculate odd multipliers for hashing: {layer_id: multipliers}
184-
self.layer_multipliers = self._calculate_multipliers_across_layers(seed)
188+
# Store as JAX arrays
189+
self.layer_multipliers = {
190+
k: jnp.array(v, dtype=jnp.int64) for k, v in self._calculate_multipliers_across_layers(seed).items()
191+
}
185192

186193
# Pre-calculate unique prime vocab sizes for every head
187194
# Structure: {layer_id: [[2gram_head1, ..., 2gram_headH], ..., [Ngram_head1, ..., Ngram_headH]]}
@@ -220,7 +227,7 @@ def _calculate_vocab_size_across_layers(self) -> dict[int, List[List[int]]]:
220227

221228
def find_next_unseen_prime(start: int, seen_primes: set) -> int:
222229
candidate = start + 1
223-
while candidate in seen_primes or not isprime(candidate):
230+
while candidate in seen_primes or not sympy.isprime(candidate):
224231
candidate += 1
225232
return candidate
226233

@@ -254,7 +261,7 @@ def get_vocab_sizes(self, layer_id: int) -> List[int]:
254261
"""
255262
return [head_size for ngram_size in self.vocab_size_across_layers[layer_id] for head_size in ngram_size]
256263

257-
def _get_ngram_hashes(self, compressed_ids: np.ndarray, layer_id: int) -> np.ndarray:
264+
def _get_ngram_hashes(self, compressed_ids: Array, layer_id: int) -> Array:
258265
"""
259266
Computes hash indices for all n-grams in the input batch.
260267
@@ -265,22 +272,21 @@ def _get_ngram_hashes(self, compressed_ids: np.ndarray, layer_id: int) -> np.nda
265272
Returns:
266273
hash_ids: [B, S, H_total] where H_total = H * num_ngram_orders
267274
"""
268-
x = np.asarray(compressed_ids, dtype=np.int64)
269-
B, S = x.shape
275+
x = jnp.asarray(compressed_ids, dtype=jnp.int64)
276+
B, _ = x.shape
270277

271278
# 1. Create Sliding Windows via Shifting
272279
shifted_inputs = []
273280
for k in range(self.max_ngram_size):
274281
if k == 0:
275-
# e.g., k=0, [The, cat, sat]
276282
shifted_inputs.append(x)
277283
else:
278284
# Pre-allocate full array with PAD_ID
279-
shifted_x = np.full((B, S), self.pad_id, dtype=np.int64)
285+
padding = jnp.full((B, k), self.pad_id, dtype=jnp.int64)
280286
# Fast memory copy, slicing and assignment
281287
# e.g., k=1, [PAD, The, cat]
282288
# k=2, [PAD, PAD, The]
283-
shifted_x[:, k:] = x[:, :-k]
289+
shifted_x = jnp.concatenate([padding, x[:, :-k]], axis=1)
284290
shifted_inputs.append(shifted_x)
285291

286292
# 2. Retrieve layer-specific hash multipliers
@@ -299,21 +305,21 @@ def _get_ngram_hashes(self, compressed_ids: np.ndarray, layer_id: int) -> np.nda
299305

300306
for n in range(2, self.max_ngram_size + 1):
301307
# Update hash with next history token
302-
ngram_hash = np.bitwise_xor(ngram_hash, shifted_inputs[n - 1] * multipliers[n - 1])
308+
ngram_hash = jnp.bitwise_xor(ngram_hash, shifted_inputs[n - 1] * multipliers[n - 1])
303309

304310
# Retrieve prime vocab sizes for all heads of this n-gram order
305311
vocab_sizes_for_this_gram = vocab_sizes[n - 2]
306-
mods = np.array(vocab_sizes_for_this_gram, dtype=np.int64)
312+
mods = jnp.array(vocab_sizes_for_this_gram, dtype=jnp.int64)
307313

308314
# Broadcast Modulo: Map hash to valid table indices
309315
# [B, S, 1] % [H] -> [B, S, H]
310316
head_hashes = ngram_hash[..., None] % mods
311317
all_hashes.append(head_hashes)
312318

313319
# Concatenate all heads: [B, S, H_total] where H_total = H * num_ngram_orders
314-
return np.concatenate(all_hashes, axis=2)
320+
return jnp.concatenate(all_hashes, axis=2)
315321

316-
def __call__(self, input_ids) -> dict[int, np.ndarray]:
322+
def __call__(self, input_ids) -> dict[int, Array]:
317323
# input_ids from standard tokenizer
318324
compressed_ids = self.compressed_tokenizer(input_ids)
319325
hash_ids_for_all_layers = {}
@@ -323,6 +329,13 @@ def __call__(self, input_ids) -> dict[int, np.ndarray]:
323329
return hash_ids_for_all_layers
324330

325331

332+
class StaticWrapper:
333+
"""Wrapper to prevent nnx from treating the value as a variable."""
334+
335+
def __init__(self, val):
336+
self.val = val
337+
338+
326339
class MultiHeadEmbedding(nnx.Module):
327340
"""
328341
A flattened table representation for multi-head embedding spaces across n-gram orders.
@@ -350,7 +363,7 @@ def __init__(
350363
# Compute starting index for each head's segment in the flattened table.
351364
# Offsets serve as the "base address" for each head.
352365
offsets = np.cumsum([0] + vocab_sizes[:-1]) # prefix sum
353-
self.offsets = jnp.array(offsets, dtype=jnp.int32)
366+
self.offsets = StaticWrapper(np.array(offsets, dtype=np.int64))
354367

355368
# The total embedding size is the sum of all individual head vocabularies.
356369
self.embedding = Embed(num_embeddings=sum(vocab_sizes), num_features=head_dim, config=config, mesh=mesh, rngs=rngs)
@@ -368,7 +381,7 @@ def __call__(self, input_ids: Array, model_mode: str = MODEL_MODE_TRAIN) -> Arra
368381
"""
369382
# Broadcasting Add: [B, S, H] + [H] -> [B, S, H]
370383
# Shifts local indices (0..Prime-1) to global table positions.
371-
shifted_ids = input_ids + self.offsets
384+
shifted_ids = input_ids + self.offsets.val
372385

373386
# Embedding lookup: [B, S, H_total] -> [B, S, H_total, D_head]
374387
return self.embedding(shifted_ids, model_mode=model_mode)

0 commit comments

Comments
 (0)