 
 """Provides op for tokenizing a dataset."""
 
-from typing import Iterable, Literal, Sequence, Collection
+from typing import Literal, Sequence, Collection
 from pathlib import Path
-import tensorflow as tf
-import tensorflow_text as tftxt
 from maxtext.utils import max_logging
 import transformers
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 from sentencepiece import SentencePieceProcessor
 
 
-Features = dict[str, tf.Tensor]
-
-
 class TikTokenTokenizer:
   """
   Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
@@ -184,33 +179,23 @@ def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int)
 
 class SentencePieceTokenizer:
   """
-  Tokenizing and encoding/decoding text using the Sentencepiece tokenizer loaded with tensorflow_text
-  """
-
-  def __init__(self, model_path: str, add_bos: bool, add_eos: bool):
-    max_logging.log(f"Tokenizer path: {model_path}")
-    with tf.io.gfile.GFile(model_path, "rb") as model_fp:
-      sp_model = model_fp.read()
-    self.sp_tokenizer = tftxt.SentencepieceTokenizer(model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=False)
-    self.pad_id = self.sp_tokenizer.string_to_id("<pad>")
-    self.unk_id = self.sp_tokenizer.string_to_id("<unk>")
-
-  def encode(self, s: str) -> list[int]:
-    return self.sp_tokenizer.tokenize(s)
-
-  def decode(self, t: Sequence[int]) -> str:
-    return self.sp_tokenizer.detokenize(t)
-
-
-class SentencePieceTokenizerGrain:
-  """
-  Tokenizing and encoding/decoding text using the Sentencepiece tokenizer loaded with sentencepiece
+  Tokenizing and encoding/decoding text using the native sentencepiece library.
+  Supports both local and GCS (gs://) model paths.
   """
 
   def __init__(self, model_path: str, add_bos: bool, add_eos: bool):
     max_logging.log(f"Loading sentencepiece tokenizer: {model_path}")
     self._tokenizer_model = SentencePieceProcessor()
-    self._tokenizer_model.Load(model_path)
+    try:
+      if model_path.startswith("gs://"):
+        from maxtext.utils.gcs_utils import read_bytes_from_gcs  # pylint: disable=import-outside-toplevel
+
+        model_proto = read_bytes_from_gcs(model_path)
+        self._tokenizer_model.LoadFromSerializedProto(model_proto)
+      else:
+        self._tokenizer_model.Load(model_path)
+    except Exception as e:
+      raise ValueError(f"Failed to load sentencepiece tokenizer from {model_path}: {e}") from e
     self.pad_id = self._tokenizer_model.pad_id()
     self.unk_id = self._tokenizer_model.unk_id()
     self.bos_id = self._tokenizer_model.bos_id()
@@ -255,7 +240,7 @@ def decode(self, t: Sequence[int]) -> str:
     return self.tokenizer.decode(t)
 
 
-def build_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token, dataset_type):
+def build_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token):
   """Loads the tokenizer at `tokenizer_path`"""
   max_logging.log(f"Tokenizer path: {tokenizer_path}")
   if tokenizer_type == "tiktoken":
@@ -264,27 +249,6 @@ def build_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_
   elif tokenizer_type == "huggingface":
     return HFTokenizer(tokenizer_path, add_bos, add_eos, hf_access_token)
   elif tokenizer_type == "sentencepiece":
-    if dataset_type == "tfds":
-      return SentencePieceTokenizer(tokenizer_path, add_bos, add_eos)
-    else:
-      return SentencePieceTokenizerGrain(tokenizer_path, add_bos, add_eos)
+    return SentencePieceTokenizer(tokenizer_path, add_bos, add_eos)
   else:
     raise ValueError(f"Invalid tokenizer_type:{tokenizer_type} chosen in config")
-
-
-def TokenizeOp(tokenizer, features: Features, data_keys: Iterable[str] = ("inputs", "targets")) -> Features:
-  """Op for tokenization"""
-
-  def _process_string(string_tensor):
-    # Extract string value and decode it if necessary
-    string_value = string_tensor.numpy().decode("utf-8")
-    # encode and extract the tokenized integers
-    modified_string = tokenizer.encode(string_value)
-    return [modified_string]
-
-  for k in data_keys:
-    if isinstance(tokenizer, (TikTokenTokenizer, HFTokenizer)):
-      features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0]
-    elif isinstance(tokenizer, SentencePieceTokenizer):
-      features[k] = tokenizer.encode(features[k])
-  return features
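
A minimal usage sketch of the API after this change: `build_tokenizer` no longer takes `dataset_type`, and the single `SentencePieceTokenizer` accepts either a local or a `gs://` model path. The import path, bucket path, and the `encode`/`decode` calls below are illustrative assumptions, not part of this diff.

```python
# Hypothetical usage sketch; the module path and tokenizer path are assumptions.
from maxtext.input_pipeline.tokenizer import build_tokenizer

# dataset_type is gone: every pipeline now gets the sentencepiece-backed
# SentencePieceTokenizer, and gs:// paths are resolved inside its constructor
# via read_bytes_from_gcs / LoadFromSerializedProto.
tokenizer = build_tokenizer(
    tokenizer_path="gs://my-bucket/tokenizers/tokenizer.model",  # assumed path
    tokenizer_type="sentencepiece",
    add_bos=True,
    add_eos=False,
    hf_access_token=None,
)

token_ids = tokenizer.encode("hello world")  # encode/decode assumed unchanged in the retained class body
text = tokenizer.decode(token_ids)
```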