from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
import numpy as np
-from sympy import isprime
-from tokenizers import Regex, normalizers
+import sympy
+import tokenizers
+from tokenizers import normalizers
+


class CompressedTokenizer:
  """
@@ -50,7 +52,8 @@ class CompressedTokenizer:

  def __init__(self, tokenizer: HFTokenizer):
    normalizer = self._build_normalizer()
-    self.lookup_table, self.num_new_token = self._build_lookup_table(tokenizer, normalizer)
+    self.lookup_table_np, self.num_new_token = self._build_lookup_table(tokenizer, normalizer)
+    self.lookup_table = jnp.array(self.lookup_table_np, dtype=jnp.int64)

  def __len__(self) -> int:
    return self.num_new_token
@@ -74,9 +77,9 @@ def _build_normalizer(self) -> normalizers.Sequence:
        # Lowercase conversion ("The" -> "the")
        normalizers.Lowercase(),
        # Collapse all whitespace variations to a single space
-        normalizers.Replace(Regex(r"[ \t\r\n]+"), " "),
+        normalizers.Replace(tokenizers.Regex(r"[ \t\r\n]+"), " "),
        # Protect standalone spaces from subsequent stripping
-        normalizers.Replace(Regex(r"^ $"), SENTINEL),
+        normalizers.Replace(tokenizers.Regex(r"^ $"), SENTINEL),
        # Remove leading/trailing whitespace
        normalizers.Strip(),
        # Restore protected spaces
@@ -118,19 +121,18 @@ def _build_lookup_table(self, tokenizer: HFTokenizer, normalizer: normalizers.Se

    return old2new, len(key2new)

-  def __call__(self, input_ids) -> np.ndarray:
+  def __call__(self, input_ids) -> Array:
    """
    Maps original token IDs to compressed IDs.
    """
-    input_ids = np.asarray(input_ids, dtype=np.int64)
+    input_ids = jnp.asarray(input_ids, dtype=jnp.int64)

-    # Identify valid tokens (ignore padding/masks usually marked with negative IDs)
-    valid_mask = input_ids >= 0
-    valid_ids = input_ids[valid_mask]
+    # Map negative IDs to 0 for lookup, then mask output back.
+    safe_ids = jnp.where(input_ids < 0, 0, input_ids)
+    mapped_ids = self.lookup_table[safe_ids]

-    # Vectorized lookup: O(1) per token
-    output_ids = input_ids.copy()
-    output_ids[valid_mask] = self.lookup_table[valid_ids]
+    # Restore negative IDs (padding)
+    output_ids = jnp.where(input_ids < 0, input_ids, mapped_ids)
    return output_ids


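The rewrite above swaps NumPy's in-place masked assignment for a functional pattern, since JAX arrays are immutable. A minimal, self-contained sketch of that pattern (the lookup table and IDs below are made up, not taken from the PR):

```python
import jax.numpy as jnp

lookup_table = jnp.array([5, 6, 7, 8])             # hypothetical old-id -> new-id mapping
input_ids = jnp.array([[2, 0, -1, -1]])            # -1 marks padding

safe_ids = jnp.where(input_ids < 0, 0, input_ids)  # route padding through a valid index
mapped_ids = lookup_table[safe_ids]                # vectorized gather
output_ids = jnp.where(input_ids < 0, input_ids, mapped_ids)  # padding passes through unchanged
# output_ids -> [[7, 5, -1, -1]]
```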
@@ -177,11 +179,16 @@ def __init__(
    # Initialize compressed tokenizer
    self.compressed_tokenizer = CompressedTokenizer(tokenizer)
    self.tokenizer_vocab_size = len(self.compressed_tokenizer)
-    if pad_id is not None:
-      self.pad_id = int(self.compressed_tokenizer.lookup_table[pad_id])
+    if pad_id is None:
+      raise ValueError("The `pad_id` must be provided and cannot be None.")
+    # Pre-calculate pad_id on CPU using numpy array to avoid ConcretizationTypeError
+    self.pad_id = int(self.compressed_tokenizer.lookup_table_np[pad_id])

    # Pre-calculate odd multipliers for hashing: {layer_id: multipliers}
-    self.layer_multipliers = self._calculate_multipliers_across_layers(seed)
+    # Store as JAX arrays
+    self.layer_multipliers = {
+        k: jnp.array(v, dtype=jnp.int64) for k, v in self._calculate_multipliers_across_layers(seed).items()
+    }

    # Pre-calculate unique prime vocab sizes for every head
    # Structure: {layer_id: [[2gram_head1, ..., 2gram_headH], ..., [Ngram_head1, ..., Ngram_headH]]}
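As an aside on the `ConcretizationTypeError` comment above (illustrative only, not code from this PR): `int(...)` needs a concrete value, which a JAX tracer cannot provide, so reading the scalar `pad_id` from the NumPy copy of the table keeps it a plain Python int regardless of any surrounding tracing.

```python
import jax
import jax.numpy as jnp
import numpy as np

table_np = np.array([7, 11, 13])   # host-side copy: always concrete
pad_id = int(table_np[0])          # fine, no tracing involved

@jax.jit
def first_as_int(x):
  # Inside jit, `x` is a tracer with no concrete value yet, so converting it
  # to a Python int raises jax.errors.ConcretizationTypeError.
  return int(x[0])

# first_as_int(jnp.arange(3))  # uncommenting this call triggers the error
```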
@@ -220,7 +227,7 @@ def _calculate_vocab_size_across_layers(self) -> dict[int, List[List[int]]]:

    def find_next_unseen_prime(start: int, seen_primes: set) -> int:
      candidate = start + 1
-      while candidate in seen_primes or not isprime(candidate):
+      while candidate in seen_primes or not sympy.isprime(candidate):
        candidate += 1
      return candidate

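For context, a hedged sketch of how a helper like `find_next_unseen_prime` is typically used (the base size and head count below are invented): every head gets its own prime table size, so an n-gram pair that collides under one modulus is unlikely to collide under all of them.

```python
import sympy

def find_next_unseen_prime(start: int, seen_primes: set) -> int:
  candidate = start + 1
  while candidate in seen_primes or not sympy.isprime(candidate):
    candidate += 1
  return candidate

seen: set = set()
base_size = 5003                    # hypothetical per-head table size
vocab_sizes = []
for _ in range(4):                  # e.g. four heads for one n-gram order
  prime = find_next_unseen_prime(base_size, seen)
  seen.add(prime)
  vocab_sizes.append(prime)
# vocab_sizes -> [5009, 5011, 5021, 5023]: distinct primes just above the base size
```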
@@ -254,7 +261,7 @@ def get_vocab_sizes(self, layer_id: int) -> List[int]:
    """
    return [head_size for ngram_size in self.vocab_size_across_layers[layer_id] for head_size in ngram_size]

-  def _get_ngram_hashes(self, compressed_ids: np.ndarray, layer_id: int) -> np.ndarray:
+  def _get_ngram_hashes(self, compressed_ids: Array, layer_id: int) -> Array:
    """
    Computes hash indices for all n-grams in the input batch.

@@ -265,22 +272,21 @@ def _get_ngram_hashes(self, compressed_ids: np.ndarray, layer_id: int) -> np.nda
    Returns:
      hash_ids: [B, S, H_total] where H_total = H * num_ngram_orders
    """
-    x = np.asarray(compressed_ids, dtype=np.int64)
-    B, S = x.shape
+    x = jnp.asarray(compressed_ids, dtype=jnp.int64)
+    B, _ = x.shape

    # 1. Create Sliding Windows via Shifting
    shifted_inputs = []
    for k in range(self.max_ngram_size):
      if k == 0:
-        # e.g., k=0, [The, cat, sat]
        shifted_inputs.append(x)
      else:
        # Pre-allocate full array with PAD_ID
-        shifted_x = np.full((B, S), self.pad_id, dtype=np.int64)
+        padding = jnp.full((B, k), self.pad_id, dtype=jnp.int64)
        # Fast memory copy, slicing and assignment
        # e.g., k=1, [PAD, The, cat]
        #       k=2, [PAD, PAD, The]
-        shifted_x[:, k:] = x[:, :-k]
+        shifted_x = jnp.concatenate([padding, x[:, :-k]], axis=1)
        shifted_inputs.append(shifted_x)

    # 2. Retrieve layer-specific hash multipliers
@@ -299,21 +305,21 @@ def _get_ngram_hashes(self, compressed_ids: np.ndarray, layer_id: int) -> np.nda

    for n in range(2, self.max_ngram_size + 1):
      # Update hash with next history token
-      ngram_hash = np.bitwise_xor(ngram_hash, shifted_inputs[n - 1] * multipliers[n - 1])
+      ngram_hash = jnp.bitwise_xor(ngram_hash, shifted_inputs[n - 1] * multipliers[n - 1])

      # Retrieve prime vocab sizes for all heads of this n-gram order
      vocab_sizes_for_this_gram = vocab_sizes[n - 2]
-      mods = np.array(vocab_sizes_for_this_gram, dtype=np.int64)
+      mods = jnp.array(vocab_sizes_for_this_gram, dtype=jnp.int64)

      # Broadcast Modulo: Map hash to valid table indices
      # [B, S, 1] % [H] -> [B, S, H]
      head_hashes = ngram_hash[..., None] % mods
      all_hashes.append(head_hashes)

    # Concatenate all heads: [B, S, H_total] where H_total = H * num_ngram_orders
-    return np.concatenate(all_hashes, axis=2)
+    return jnp.concatenate(all_hashes, axis=2)

-  def __call__(self, input_ids) -> dict[int, np.ndarray]:
+  def __call__(self, input_ids) -> dict[int, Array]:
    # input_ids from standard tokenizer
    compressed_ids = self.compressed_tokenizer(input_ids)
    hash_ids_for_all_layers = {}
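A self-contained sketch of the hashing scheme ported to JAX above, with invented multipliers and prime table sizes (the real ones come from `layer_multipliers` and `vocab_size_across_layers`): each history position is scaled by its own odd multiplier, folded in with XOR, then reduced modulo a per-head prime.

```python
import jax.numpy as jnp

PAD_ID = 0
x = jnp.array([[3, 5, 7, 11]])                                  # [B=1, S=4] compressed IDs

# Shift-and-pad history, as in the diff: k=1 -> [PAD, 3, 5, 7]
shift1 = jnp.concatenate([jnp.full((1, 1), PAD_ID), x[:, :-1]], axis=1)

multipliers = jnp.array([104729, 40503])                        # hypothetical odd multipliers
ngram_hash = x * multipliers[0]                                 # current-token term (initialization assumed; not shown in this diff)
ngram_hash = jnp.bitwise_xor(ngram_hash, shift1 * multipliers[1])  # fold in the previous token

mods = jnp.array([5009, 5011])                                  # one prime table size per head
head_hashes = ngram_hash[..., None] % mods                      # [B, S, 1] % [H] -> [B, S, H]
```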
@@ -323,6 +329,13 @@ def __call__(self, input_ids) -> dict[int, np.ndarray]:
    return hash_ids_for_all_layers


+class StaticWrapper:
+  """Wrapper to prevent nnx from treating the value as a variable."""
+
+  def __init__(self, val):
+    self.val = val
+
+
class MultiHeadEmbedding(nnx.Module):
  """
  A flattened table representation for multi-head embedding spaces across n-gram orders.
@@ -350,7 +363,7 @@ def __init__(
    # Compute starting index for each head's segment in the flattened table.
    # Offsets serve as the "base address" for each head.
    offsets = np.cumsum([0] + vocab_sizes[:-1])  # prefix sum
-    self.offsets = jnp.array(offsets, dtype=jnp.int32)
+    self.offsets = StaticWrapper(np.array(offsets, dtype=np.int64))

    # The total embedding size is the sum of all individual head vocabularies.
    self.embedding = Embed(num_embeddings=sum(vocab_sizes), num_features=head_dim, config=config, mesh=mesh, rngs=rngs)
@@ -368,7 +381,7 @@ def __call__(self, input_ids: Array, model_mode: str = MODEL_MODE_TRAIN) -> Arra
    """
    # Broadcasting Add: [B, S, H] + [H] -> [B, S, H]
    # Shifts local indices (0..Prime-1) to global table positions.
-    shifted_ids = input_ids + self.offsets
+    shifted_ids = input_ids + self.offsets.val

    # Embedding lookup: [B, S, H_total] -> [B, S, H_total, D_head]
    return self.embedding(shifted_ids, model_mode=model_mode)
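To round out the offset arithmetic in `MultiHeadEmbedding` (a toy sketch with made-up sizes, not the actual `Embed` module): the prefix-sum offsets give every head a disjoint slice of one flattened table, and wrapping them in `StaticWrapper` keeps them as plain NumPy constants outside nnx's tracked state.

```python
import numpy as np
import jax.numpy as jnp

vocab_sizes = [5, 7, 11]                              # toy per-head table sizes
offsets = np.cumsum([0] + vocab_sizes[:-1])           # [0, 5, 12]: base address of each head

# Stand-in for the flattened embedding table: 23 rows, feature dim 2
flat_table = jnp.arange(sum(vocab_sizes) * 2.0).reshape(sum(vocab_sizes), 2)

input_ids = jnp.array([[[4, 0, 10]]])                 # [B=1, S=1, H=3] local indices per head
shifted_ids = input_ids + offsets                     # [[[4, 5, 22]]] global row positions
embeddings = flat_table[shifted_ids]                  # [B, S, H, D] gather from one big table
```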