
Commit 7963fc8

remove tf dependency from tokenizer.py
1 parent 0f59a69 commit 7963fc8

9 files changed

Lines changed: 65 additions & 76 deletions


src/maxtext/input_pipeline/grain_data_processing.py

Lines changed: 0 additions & 2 deletions
@@ -213,7 +213,6 @@ def pretrain_preprocessing_pipeline(
       config.add_bos,
       config.add_eos,
       config.hf_access_token,
-      config.dataset_type,
   )
   if tokenizer_model.pad_id is not None:
     pad_id = tokenizer_model.pad_id
@@ -321,7 +320,6 @@ def dpo_preprocessing_pipeline(
       config.add_bos,
       config.add_eos,
       config.hf_access_token,
-      config.dataset_type,
   )
   if tokenizer_model.pad_id is not None:
     pad_id = tokenizer_model.pad_id
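
For context, both call sites above now build the tokenizer without a dataset_type argument. A minimal sketch of the updated call, assuming a loaded config object and the surrounding input_pipeline_utils.get_tokenizer(...) call that sit just outside this hunk (config field names other than those visible above are assumptions):

from maxtext.input_pipeline import input_pipeline_utils

# Sketch only: the tokenizer_path / tokenizer_type fields are assumed; the rest is visible in the hunk.
tokenizer_model = input_pipeline_utils.get_tokenizer(
    config.tokenizer_path,
    config.tokenizer_type,
    config.add_bos,
    config.add_eos,
    config.hf_access_token,
)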

src/maxtext/input_pipeline/grain_tokenizer.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ class TokenizerTransformBase:
   # pylint: disable=attribute-defined-outside-init
   feature_names: str | Sequence[str]
   sequence_length: int | Sequence[int]
-  tokenizer: tokenizer.SentencePieceTokenizerGrain | tokenizer.HFTokenizer
+  tokenizer: tokenizer.SentencePieceTokenizer | tokenizer.HFTokenizer

   def __post_init__(self):
     self._processor = None

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 18 additions & 5 deletions
@@ -17,7 +17,7 @@
 import dataclasses
 import warnings
 from threading import current_thread
-from typing import Any, TYPE_CHECKING
+from typing import Any, Iterable, TYPE_CHECKING

 if TYPE_CHECKING:
   import datasets
@@ -40,11 +40,9 @@ def normalize_features(x, column_name):
   return {"inputs": x[column_name], "targets": x[column_name]}


-def get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token=None, dataset_type="tfds"):
+def get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token=None):
   # Load tokenizer
-  tokenizer_model = tokenizer.build_tokenizer(
-      tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token, dataset_type
-  )
+  tokenizer_model = tokenizer.build_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token)
   return tokenizer_model


@@ -67,6 +65,21 @@ def add_segmentation_and_position(x, data_columns, padding_token=0):
   return x


+def TokenizeOp(tokenizer_model, features: Features, data_keys: Iterable[str] = ("inputs", "targets")) -> Features:
+  """Op for tokenization"""
+
+  def _process_string(string_tensor):
+    # Extract string value and decode it if necessary
+    string_value = string_tensor.numpy().decode("utf-8")
+    # encode and extract the tokenized integers
+    modified_string = tokenizer_model.encode(string_value)
+    return [modified_string]
+
+  for k in data_keys:
+    features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0]
+  return features
+
+
 ########## Functions used by HF pipeline
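
TokenizeOp now lives in input_pipeline_utils rather than tokenizer.py. It wraps tokenizer_model.encode in tf.py_function so tokenization can run inside a tf.data map over string features, which is how the tfds pipelines below call it. A rough usage sketch, assuming a sentencepiece tokenizer asset; the path and example strings are placeholders:

import tensorflow as tf
from maxtext.input_pipeline import input_pipeline_utils

# Placeholder tokenizer path; any sentencepiece model file would do here.
tokenizer_model = input_pipeline_utils.get_tokenizer(
    "assets/tokenizer.llama2", "sentencepiece", add_bos=True, add_eos=False
)
ds = tf.data.Dataset.from_tensor_slices({"inputs": ["hello world"], "targets": ["hello world"]})
ds = ds.map(
    lambda x: input_pipeline_utils.TokenizeOp(
        tokenizer_model=tokenizer_model, features=x, data_keys=("inputs", "targets")
    ),
    num_parallel_calls=tf.data.AUTOTUNE,
)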

src/maxtext/input_pipeline/tfds_data_processing.py

Lines changed: 3 additions & 2 deletions
@@ -25,7 +25,6 @@
 import jax

 from maxtext.input_pipeline import multihost_dataloading
-from maxtext.input_pipeline import tokenizer
 from maxtext.input_pipeline.packing import sequence_packing
 from maxtext.input_pipeline import input_pipeline_utils

@@ -116,7 +115,9 @@ def preprocessing_pipeline(

   if tokenize:
     dataset = dataset.map(
-        lambda x: tokenizer.TokenizeOp(tokenizer=tokenizer_model, features=x, data_keys=data_column_names),
+        lambda x: input_pipeline_utils.TokenizeOp(
+            tokenizer_model=tokenizer_model, features=x, data_keys=data_column_names
+        ),
         num_parallel_calls=AUTOTUNE,
     )

src/maxtext/input_pipeline/tfds_data_processing_c4_mlperf.py

Lines changed: 3 additions & 3 deletions
@@ -27,10 +27,10 @@
 import jax.numpy as jnp
 from jax.experimental import multihost_utils

-from maxtext.input_pipeline import tokenizer
 from maxtext.input_pipeline import multihost_dataloading
 from maxtext.input_pipeline.packing import sequence_packing
 from maxtext.input_pipeline.input_pipeline_utils import get_tokenizer
+from maxtext.input_pipeline.input_pipeline_utils import TokenizeOp
 from maxtext.utils import max_logging

 AUTOTUNE = tf.data.experimental.AUTOTUNE
@@ -258,7 +258,7 @@ def preprocess_train_dataset(
   else:
     pad_id = -1
   train_ds = train_ds.map(
-      lambda x: tokenizer.TokenizeOp(tokenizer=sp_tokenizer, features=x, data_keys=("targets",)),
+      lambda x: TokenizeOp(tokenizer_model=sp_tokenizer, features=x, data_keys=("targets",)),
       num_parallel_calls=AUTOTUNE,
   )
   train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=4096)
@@ -283,7 +283,7 @@ def preprocess_eval_dataset(
   # group text up to max_target_length if the dataset is not pre-tokenized/pre-processed
   if not is_tokenized_dataset:
     eval_ds = eval_ds.map(
-        lambda x: tokenizer.TokenizeOp(tokenizer=sp_tokenizer, features=x, data_keys=("targets",)),
+        lambda x: TokenizeOp(tokenizer_model=sp_tokenizer, features=x, data_keys=("targets",)),
         num_parallel_calls=AUTOTUNE,
     )
   # hardcode batch_sizes 24567 i.e. the exp size in split validation_24567exp

src/maxtext/input_pipeline/tokenizer.py

Lines changed: 15 additions & 51 deletions
@@ -14,20 +14,15 @@

 """Provides op for tokenizing a dataset."""

-from typing import Iterable, Literal, Sequence, Collection
+from typing import Literal, Sequence, Collection
 from pathlib import Path
-import tensorflow as tf
-import tensorflow_text as tftxt
 from maxtext.utils import max_logging
 import transformers
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 from sentencepiece import SentencePieceProcessor


-Features = dict[str, tf.Tensor]
-
-
 class TikTokenTokenizer:
   """
   Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
@@ -184,33 +179,23 @@ def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int)

 class SentencePieceTokenizer:
   """
-  Tokenizing and encoding/decoding text using the Sentencepiece tokenizer loaded with tensorflow_text
-  """
-
-  def __init__(self, model_path: str, add_bos: bool, add_eos: bool):
-    max_logging.log(f"Tokenizer path: {model_path}")
-    with tf.io.gfile.GFile(model_path, "rb") as model_fp:
-      sp_model = model_fp.read()
-    self.sp_tokenizer = tftxt.SentencepieceTokenizer(model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=False)
-    self.pad_id = self.sp_tokenizer.string_to_id("<pad>")
-    self.unk_id = self.sp_tokenizer.string_to_id("<unk>")
-
-  def encode(self, s: str) -> list[int]:
-    return self.sp_tokenizer.tokenize(s)
-
-  def decode(self, t: Sequence[int]) -> str:
-    return self.sp_tokenizer.detokenize(t)
-
-
-class SentencePieceTokenizerGrain:
-  """
-  Tokenizing and encoding/decoding text using the Sentencepiece tokenizer loaded with sentencepiece
+  Tokenizing and encoding/decoding text using the native sentencepiece library.
+  Supports both local and GCS (gs://) model paths.
   """

   def __init__(self, model_path: str, add_bos: bool, add_eos: bool):
     max_logging.log(f"Loading sentencepiece tokenizer: {model_path}")
     self._tokenizer_model = SentencePieceProcessor()
-    self._tokenizer_model.Load(model_path)
+    try:
+      if model_path.startswith("gs://"):
+        from maxtext.utils.gcs_utils import read_bytes_from_gcs  # pylint: disable=import-outside-toplevel
+
+        model_proto = read_bytes_from_gcs(model_path)
+        self._tokenizer_model.LoadFromSerializedProto(model_proto)
+      else:
+        self._tokenizer_model.Load(model_path)
+    except Exception as e:
+      raise ValueError(f"Failed to load sentencepiece tokenizer from {model_path}: {e}") from e
     self.pad_id = self._tokenizer_model.pad_id()
     self.unk_id = self._tokenizer_model.unk_id()
     self.bos_id = self._tokenizer_model.bos_id()
@@ -255,7 +240,7 @@ def decode(self, t: Sequence[int]) -> str:
     return self.tokenizer.decode(t)


-def build_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token, dataset_type):
+def build_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token):
   """Loads the tokenizer at `tokenizer_path`"""
   max_logging.log(f"Tokenizer path: {tokenizer_path}")
   if tokenizer_type == "tiktoken":
@@ -264,27 +249,6 @@ def build_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_
   elif tokenizer_type == "huggingface":
     return HFTokenizer(tokenizer_path, add_bos, add_eos, hf_access_token)
   elif tokenizer_type == "sentencepiece":
-    if dataset_type == "tfds":
-      return SentencePieceTokenizer(tokenizer_path, add_bos, add_eos)
-    else:
-      return SentencePieceTokenizerGrain(tokenizer_path, add_bos, add_eos)
+    return SentencePieceTokenizer(tokenizer_path, add_bos, add_eos)
   else:
     raise ValueError(f"Invalid tokenizer_type:{tokenizer_type} chosen in config")
-
-
-def TokenizeOp(tokenizer, features: Features, data_keys: Iterable[str] = ("inputs", "targets")) -> Features:
-  """Op for tokenization"""
-
-  def _process_string(string_tensor):
-    # Extract string value and decode it if necessary
-    string_value = string_tensor.numpy().decode("utf-8")
-    # encode and extract the tokenized integers
-    modified_string = tokenizer.encode(string_value)
-    return [modified_string]
-
-  for k in data_keys:
-    if isinstance(tokenizer, (TikTokenTokenizer, HFTokenizer)):
-      features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0]
-    elif isinstance(tokenizer, SentencePieceTokenizer):
-      features[k] = tokenizer.encode(features[k])
-  return features
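
With the tensorflow_text-backed class removed, a single SentencePieceTokenizer built on the sentencepiece package serves both the tfds and grain paths, and build_tokenizer no longer takes dataset_type. A hedged sketch of the new entry point; the model path is a placeholder, and the encode/decode round trip assumes those helpers on the class (they sit outside this hunk, though encode is exercised in tokenizer_test.py):

from maxtext.input_pipeline import tokenizer

# Placeholder path: a local file or a gs:// object both load after this change.
sp_tok = tokenizer.build_tokenizer(
    "assets/tokenizer.llama2", "sentencepiece", add_bos=True, add_eos=False, hf_access_token=None
)
ids = sp_tok.encode("This is a test")  # token ids
text = sp_tok.decode(ids)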

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 0 additions & 1 deletion
@@ -408,7 +408,6 @@ def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyco
       add_bos=student_config.add_bos,
       add_eos=student_config.add_eos,
       hf_access_token=student_config.hf_access_token,
-      dataset_type=student_config.dataset_type,
   )
   pad_id = tok.pad_id if tok.pad_id is not None else 0

src/maxtext/utils/gcs_utils.py

Lines changed: 24 additions & 10 deletions
@@ -177,31 +177,45 @@ def gcs_glob_pattern(pattern):
   return data_files


-def read_json_from_gcs(file_path):
-  """
-  Read a json file from gcs bucket.
+def read_bytes_from_gcs(file_path):
+  """Read raw bytes from a GCS file.

   Args:
-    file_path: The gcs path of the json file.
+    file_path: The gcs path of the file (e.g. gs://bucket/path/to/file).

   Returns:
-    A dictionary with content from json file.
+    The file contents as bytes, or None if unavailable.
   """
-  if not _gcs_guard("read_json_from_gcs"):
+  if not _gcs_guard("read_bytes_from_gcs"):
     return None
   try:
     storage_client = storage.Client()
     bucket_name, file_prefix = parse_gcs_bucket_and_prefix(file_path)
     bucket = storage_client.bucket(bucket_name)
     blob = bucket.blob(file_prefix)
+    return blob.download_as_bytes()
+  except Exception as e:  # pylint: disable=broad-except
+    max_logging.log(f"Error reading bytes from GCS path {file_path}: {e}")
+    return None

-    json_string = blob.download_as_string()

-    data = json.loads(json_string)
+def read_json_from_gcs(file_path):
+  """
+  Read a json file from gcs bucket.
+
+  Args:
+    file_path: The gcs path of the json file.

-    return data
+  Returns:
+    A dictionary with content from json file.
+  """
+  try:
+    raw = read_bytes_from_gcs(file_path)
+    if raw is None:
+      return None
+    return json.loads(raw)
   except (ValueError, TypeError, json.JSONDecodeError) as e:
-    print(f"Error reading JSON file from GCS: {str(e)}")
+    max_logging.log(f"Error reading JSON file from GCS: {str(e)}")
     return None
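
read_json_from_gcs is now a thin wrapper over the new read_bytes_from_gcs, which is the same helper the sentencepiece loader above uses for gs:// model paths. A small sketch with placeholder bucket paths:

from maxtext.utils.gcs_utils import read_bytes_from_gcs, read_json_from_gcs

# Both helpers log and return None on failure instead of raising.
model_proto = read_bytes_from_gcs("gs://my-bucket/tokenizers/tokenizer.model")  # bytes or None
run_config = read_json_from_gcs("gs://my-bucket/configs/run_config.json")  # dict or None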

tests/unit/tokenizer_test.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def tearDownClass(cls):
   @pytest.mark.tpu_only
   def test_tokenize(self):
     text = "This is a test"
-    self.assertTrue(np.array_equal(self.source_tokenizer.encode(text).numpy(), self.test_tokenizer.encode(text).numpy()))
+    self.assertTrue(np.array_equal(self.source_tokenizer.encode(text), self.test_tokenizer.encode(text)))

   @pytest.mark.tpu_only
   def test_detokenize(self):

0 commit comments
