 from sympy import isprime
 from tokenizers import Regex, normalizers
 
+
 class CompressedTokenizer:
   """
   A canonicalizing wrapper that reduces vocabulary sparsity for n-gram lookup.
@@ -50,7 +51,8 @@ class CompressedTokenizer:
 
   def __init__(self, tokenizer: HFTokenizer):
     normalizer = self._build_normalizer()
-    self.lookup_table, self.num_new_token = self._build_lookup_table(tokenizer, normalizer)
+    self.lookup_table_np, self.num_new_token = self._build_lookup_table(tokenizer, normalizer)
+    self.lookup_table = jnp.array(self.lookup_table_np, dtype=jnp.int64)
 
   def __len__(self) -> int:
     return self.num_new_token
@@ -118,19 +120,18 @@ def _build_lookup_table(self, tokenizer: HFTokenizer, normalizer: normalizers.Se
 
     return old2new, len(key2new)
 
-  def __call__(self, input_ids) -> np.ndarray:
+  def __call__(self, input_ids) -> Array:
     """
     Maps original token IDs to compressed IDs.
     """
-    input_ids = np.asarray(input_ids, dtype=np.int64)
+    input_ids = jnp.asarray(input_ids, dtype=jnp.int64)
 
-    # Identify valid tokens (ignore padding/masks usually marked with negative IDs)
-    valid_mask = input_ids >= 0
-    valid_ids = input_ids[valid_mask]
+    # Map negative IDs to 0 for lookup, then mask output back.
+    safe_ids = jnp.where(input_ids < 0, 0, input_ids)
+    mapped_ids = self.lookup_table[safe_ids]
 
-    # Vectorized lookup: O(1) per token
-    output_ids = input_ids.copy()
-    output_ids[valid_mask] = self.lookup_table[valid_ids]
+    # Restore negative IDs (padding)
+    output_ids = jnp.where(input_ids < 0, input_ids, mapped_ids)
     return output_ids
 

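For reference, a minimal standalone sketch of the masked-lookup pattern introduced in `__call__` above, with a made-up toy lookup table (the table and ID values are illustrative, not from the PR):

```python
import jax.numpy as jnp

# Hypothetical 6-entry table mapping original token IDs -> compressed IDs.
lookup_table = jnp.array([0, 1, 1, 2, 3, 2], dtype=jnp.int64)

input_ids = jnp.array([[3, 5, -1, -1]], dtype=jnp.int64)  # -1 marks padding

safe_ids = jnp.where(input_ids < 0, 0, input_ids)             # padding -> index 0, safe to gather
mapped_ids = lookup_table[safe_ids]                           # gather compressed IDs
output_ids = jnp.where(input_ids < 0, input_ids, mapped_ids)  # restore padding unchanged

print(output_ids)  # [[ 2  2 -1 -1]]
```

Unlike the removed NumPy boolean-mask assignment, this where/gather/where form keeps shapes static, so it remains traceable under `jax.jit`.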
@@ -177,11 +178,16 @@ def __init__(
     # Initialize compressed tokenizer
     self.compressed_tokenizer = CompressedTokenizer(tokenizer)
     self.tokenizer_vocab_size = len(self.compressed_tokenizer)
-    if pad_id is not None:
-      self.pad_id = int(self.compressed_tokenizer.lookup_table[pad_id])
+    if pad_id is None:
+      raise ValueError("The `pad_id` must be provided and cannot be None.")
+    # Pre-calculate pad_id on CPU using numpy array to avoid ConcretizationTypeError
+    self.pad_id = int(self.compressed_tokenizer.lookup_table_np[pad_id])
 
     # Pre-calculate odd multipliers for hashing: {layer_id: multipliers}
-    self.layer_multipliers = self._calculate_multipliers_across_layers(seed)
+    # Store as JAX arrays
+    self.layer_multipliers = {
+        k: jnp.array(v, dtype=jnp.int64) for k, v in self._calculate_multipliers_across_layers(seed).items()
+    }
 
     # Pre-calculate unique prime vocab sizes for every head
     # Structure: {layer_id: [[2gram_head1, ..., 2gram_headH], ..., [Ngram_head1, ..., Ngram_headH]]}
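The helpers that generate the hash multipliers and the per-head prime vocab sizes are not shown in this diff. Purely as an illustrative assumption of how such values could be produced (the function names and logic below are hypothetical, and use `sympy.nextprime` rather than the file's `isprime` import):

```python
import numpy as np
from sympy import nextprime

def make_odd_multipliers(num_ngrams: int, seed: int) -> np.ndarray:
  """Illustrative only: draw random odd 62-bit multipliers for XOR hashing."""
  rng = np.random.default_rng(seed)
  return rng.integers(1, 2**62, size=num_ngrams, dtype=np.int64) | 1  # low bit set -> always odd

def make_prime_vocab_sizes(base_size: int, num_heads: int) -> list[int]:
  """Illustrative only: pick distinct primes near a target table size, one per head."""
  sizes, p = [], base_size
  for _ in range(num_heads):
    p = nextprime(p)  # strictly increasing, so every head gets a distinct prime
    sizes.append(int(p))
  return sizes
```

Distinct prime moduli keep the per-head hash buckets from colliding in lockstep across heads.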
@@ -254,7 +260,7 @@ def get_vocab_sizes(self, layer_id: int) -> List[int]:
     """
     return [head_size for ngram_size in self.vocab_size_across_layers[layer_id] for head_size in ngram_size]
 
-  def _get_ngram_hashes(self, compressed_ids: np.ndarray, layer_id: int) -> np.ndarray:
+  def _get_ngram_hashes(self, compressed_ids: Array, layer_id: int) -> Array:
     """
     Computes hash indices for all n-grams in the input batch.
 
@@ -265,22 +271,21 @@ def _get_ngram_hashes(self, compressed_ids: np.ndarray, layer_id: int) -> np.nda
     Returns:
       hash_ids: [B, S, H_total] where H_total = H * num_ngram_orders
     """
-    x = np.asarray(compressed_ids, dtype=np.int64)
-    B, S = x.shape
+    x = jnp.asarray(compressed_ids, dtype=jnp.int64)
+    B, _ = x.shape
 
     # 1. Create Sliding Windows via Shifting
     shifted_inputs = []
     for k in range(self.max_ngram_size):
       if k == 0:
-        # e.g., k=0, [The, cat, sat]
        shifted_inputs.append(x)
       else:
         # Pre-allocate full array with PAD_ID
-        shifted_x = np.full((B, S), self.pad_id, dtype=np.int64)
+        padding = jnp.full((B, k), self.pad_id, dtype=jnp.int64)
         # Fast memory copy, slicing and assignment
         # e.g., k=1, [PAD, The, cat]
         #       k=2, [PAD, PAD, The]
-        shifted_x[:, k:] = x[:, :-k]
+        shifted_x = jnp.concatenate([padding, x[:, :-k]], axis=1)
         shifted_inputs.append(shifted_x)
 
     # 2. Retrieve layer-specific hash multipliers
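Before step 2, a toy trace of the shifting in step 1 above, assuming `max_ngram_size=3` and `pad_id=0` (token values are made up):

```python
import jax.numpy as jnp

x = jnp.array([[11, 12, 13, 14]], dtype=jnp.int64)  # compressed IDs, batch of one
pad_id = 0

shifted_inputs = [x]
for k in range(1, 3):
  padding = jnp.full((x.shape[0], k), pad_id, dtype=jnp.int64)
  shifted_inputs.append(jnp.concatenate([padding, x[:, :-k]], axis=1))

# shifted_inputs[0]: [[11, 12, 13, 14]]  current token
# shifted_inputs[1]: [[ 0, 11, 12, 13]]  previous token
# shifted_inputs[2]: [[ 0,  0, 11, 12]]  token two steps back
```

The `jnp.concatenate` form replaces the old in-place slice assignment, which immutable JAX arrays do not support.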
@@ -299,21 +304,21 @@ def _get_ngram_hashes(self, compressed_ids: np.ndarray, layer_id: int) -> np.nda
 
     for n in range(2, self.max_ngram_size + 1):
       # Update hash with next history token
-      ngram_hash = np.bitwise_xor(ngram_hash, shifted_inputs[n - 1] * multipliers[n - 1])
+      ngram_hash = jnp.bitwise_xor(ngram_hash, shifted_inputs[n - 1] * multipliers[n - 1])
 
       # Retrieve prime vocab sizes for all heads of this n-gram order
       vocab_sizes_for_this_gram = vocab_sizes[n - 2]
-      mods = np.array(vocab_sizes_for_this_gram, dtype=np.int64)
+      mods = jnp.array(vocab_sizes_for_this_gram, dtype=jnp.int64)
 
       # Broadcast Modulo: Map hash to valid table indices
       # [B, S, 1] % [H] -> [B, S, H]
       head_hashes = ngram_hash[..., None] % mods
       all_hashes.append(head_hashes)
 
     # Concatenate all heads: [B, S, H_total] where H_total = H * num_ngram_orders
-    return np.concatenate(all_hashes, axis=2)
+    return jnp.concatenate(all_hashes, axis=2)
 
-  def __call__(self, input_ids) -> dict[int, np.ndarray]:
+  def __call__(self, input_ids) -> dict[int, Array]:
     # input_ids from standard tokenizer
     compressed_ids = self.compressed_tokenizer(input_ids)
     hash_ids_for_all_layers = {}
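Continuing the toy example, a sketch of how a single 2-gram order with two heads could be hashed. The multipliers, prime sizes, and the unigram starting hash are assumptions for illustration; the hash's initialization happens outside the lines shown in this hunk:

```python
import jax.numpy as jnp

cur = jnp.array([[11, 12, 13, 14]], dtype=jnp.int64)   # shifted_inputs[0]
prev = jnp.array([[0, 11, 12, 13]], dtype=jnp.int64)   # shifted_inputs[1]

multipliers = jnp.array([1, 2654435761], dtype=jnp.int64)  # illustrative odd multipliers
mods = jnp.array([1009, 2003], dtype=jnp.int64)            # illustrative prime sizes, one per head

ngram_hash = cur * multipliers[0]                                # assumed unigram starting hash
ngram_hash = jnp.bitwise_xor(ngram_hash, prev * multipliers[1])  # fold in one history token

# Broadcast modulo maps the shared hash into each head's own table: [B, S, 1] % [H] -> [B, S, H]
head_hashes = ngram_hash[..., None] % mods
print(head_hashes.shape)  # (1, 4, 2)
```

Note that `jnp.int64` is only honoured when 64-bit mode is enabled (`jax.config.update("jax_enable_x64", True)`); otherwise JAX silently downcasts these arrays to int32.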
@@ -323,6 +328,13 @@ def __call__(self, input_ids) -> dict[int, np.ndarray]:
     return hash_ids_for_all_layers
 
 
+class StaticWrapper:
+  """Wrapper to prevent nnx from treating the value as a variable."""
+
+  def __init__(self, val):
+    self.val = val
+
+
 class MultiHeadEmbedding(nnx.Module):
   """
   A flattened table representation for multi-head embedding spaces across n-gram orders.
@@ -350,7 +362,7 @@ def __init__(
     # Compute starting index for each head's segment in the flattened table.
     # Offsets serve as the "base address" for each head.
    offsets = np.cumsum([0] + vocab_sizes[:-1])  # prefix sum
-    self.offsets = jnp.array(offsets, dtype=jnp.int32)
+    self.offsets = StaticWrapper(np.array(offsets, dtype=np.int64))
 
     # The total embedding size is the sum of all individual head vocabularies.
     self.embedding = Embed(num_embeddings=sum(vocab_sizes), num_features=head_dim, config=config, mesh=mesh, rngs=rngs)
@@ -368,7 +380,7 @@ def __call__(self, input_ids: Array, model_mode: str = MODEL_MODE_TRAIN) -> Arra
     """
     # Broadcasting Add: [B, S, H] + [H] -> [B, S, H]
     # Shifts local indices (0..Prime-1) to global table positions.
-    shifted_ids = input_ids + self.offsets
+    shifted_ids = input_ids + self.offsets.val
 
     # Embedding lookup: [B, S, H_total] -> [B, S, H_total, D_head]
     return self.embedding(shifted_ids, model_mode=model_mode)
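A small worked example of the offset arithmetic behind `MultiHeadEmbedding` (the vocab sizes here are made-up primes):

```python
import numpy as np

vocab_sizes = [1009, 2003, 3001]             # illustrative per-head table sizes
offsets = np.cumsum([0] + vocab_sizes[:-1])  # -> [0, 1009, 3012], one base address per head

local_ids = np.array([5, 7, 11])             # per-head hash indices for one token position
global_ids = local_ids + offsets             # -> [5, 1016, 3023] in the flattened table

assert sum(vocab_sizes) == 6013              # single table covering all heads
```

Flattening the heads into one table of size `sum(vocab_sizes)` turns H separate lookups into a single gather over globally shifted indices.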