Commit 708c406

Merge pull request #3220 from AI-Hypercomputer:bvandermoon-repo-restructure
PiperOrigin-RevId: 874382885
2 parents 7f0028b + 89b8dd9 commit 708c406

3 files changed

Lines changed: 180 additions & 142 deletions

src/MaxText/train_tokenizer.py

Lines changed: 16 additions & 141 deletions
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,150 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-""" Train tokenizer
-Example usage: python3 -m MaxText.train_tokenizer --dataset_path=gs://maxtext-dataset --dataset_name=c4/en:3.0.1
-"""
+"""Shim for `train_tokenizer` in `src/maxtext/trainers/tokenizer`."""

-import os
-import tempfile
-import time
-
-from absl import app
-from absl import flags
 from absl import logging

-from sentencepiece import SentencePieceTrainer
-
-import jax
-
-import tensorflow as tf
-import tensorflow_datasets as tfds
-
-from MaxText.globals import MAXTEXT_ASSETS_ROOT
-
-_DATASET_PATH = flags.DEFINE_string("dataset_path", None, "Path to the dataset", required=True)
-_DATASET_NAME = flags.DEFINE_string("dataset_name", None, "Name to the dataset", required=True)
-_VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size")
-_MAX_CORPUS_CHARS = flags.DEFINE_integer("max_corpus_chars", 10_000_000, "Max corpus chars")
-_ASSETS_PATH = flags.DEFINE_string("assets_path", MAXTEXT_ASSETS_ROOT, "Name to the dataset")
-_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Name to the dataset")
-
-
-def _dump_chars_to_textfile(dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]:
-  """Write part of a TFDS sentence dataset to lines in a text file.
-  Args:
-    dataset: tf.dataset containing string-data.
-    maxchars: int: approximate number of characters to save from dataset.
-    data_keys: tuple[str]: what keys in dataset to dump from.
-  Returns:
-    name of temp file with dataset bytes, exact number of characters dumped.
-  """
-  char_count = 0
-  ds_iter = dataset.as_numpy_iterator()
-  temp_dir = tempfile.gettempdir()
-  with tempfile.NamedTemporaryFile(delete=False, prefix=os.path.join(temp_dir, "ds_chars")) as outfp:
-    while char_count < maxchars:
-      example = next(ds_iter)
-      for k in data_keys:
-        line = example[k] + b"\n"
-        char_count += len(line)
-        outfp.write(line)
-  return outfp.name, char_count
-
-
-def _train_sentencepiece(
-    dataset: tf.data.Dataset,
-    *,
-    vocab_size: int,
-    maxchars: int = int(1e7),
-    model_path: str,
-    model_type: str = "unigram",
-    character_coverage: float = 1.0,
-    data_keys=("text",),
-):
-  """Train SentencePiece tokenizer from subset of tf dataset.
-  Args:
-    dataset: tf.dataset
-    vocab_size: int: size of vocab tokens to train.
-    maxchars: int: number of characters to use for sentencepiece training.
-    model_path: str: path of model file to save vocab model to.
-    model_type: str: type of sentencepiece vocab to train.
-    character_coverage: amount of characters covered by the model, good defaults
-      are 0.9995 for languages with rich character set like Japanese or Chinese
-      and 1.0 for other languages with small character set.
-    data_keys: tuple[str]: keys of dataset to use for training.
-  Returns:
-    path to the trained sentencepiece vocabulary model.
-  """
-  if model_path.startswith("gs://"):
-    abs_model_path = model_path
-  else:
-    abs_model_path = os.path.abspath(os.path.expanduser(model_path))
-  fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars, data_keys=data_keys)
-  temp_dir = tempfile.gettempdir()
-  with tempfile.NamedTemporaryFile(delete=False, prefix=os.path.join(temp_dir, "sp_tmp")) as model_fp:
-    pass  # we just want a prefix'd tmp-filename
-  argstr = " ".join(
-      [
-          f"--input={fname}",
-          f"--vocab_size={vocab_size}",
-          f"--character_coverage={character_coverage}",
-          f"--model_prefix={model_fp.name}",
-          f"--model_type={model_type}",
-      ]
-  )
-  SentencePieceTrainer.Train(argstr)
-  if jax.process_index() == 0:
-    # Use an intermediate filename that is renamed to the target name to address
-    # create and fill delays.
-    copy_rename_path = abs_model_path + ".rntmp"
-    tf.io.gfile.makedirs(os.path.dirname(abs_model_path))
-    tf.io.gfile.copy(model_fp.name + ".model", copy_rename_path, overwrite=True)
-    tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True)
-    logging.info("copied %s to %s", model_fp.name + ".model", abs_model_path)
-  else:
-    while not tf.io.gfile.exists(abs_model_path):
-      time.sleep(1)
-    time.sleep(1)
-  return abs_model_path
-
-
-def train_tokenizer(
-    dataset: tf.data.Dataset,
-    *,
-    vocab_path: str,
-    vocab_size: int,
-    max_corpus_chars: int,
-    data_keys: tuple[str] = ("text",),
-):
-  """tokenizer training function"""
-  logging.info("SentencePiece vocab not found, building one from data.")
-  vocab_path = _train_sentencepiece(
-      dataset,
-      vocab_size=vocab_size,
-      maxchars=max_corpus_chars,
-      model_path=vocab_path,
-      data_keys=data_keys,
-  )
-  logging.info("Model saved at %s", vocab_path)
-
-
-def main(argv):
-  del argv
-  os.environ["TFDS_DATA_DIR"] = _DATASET_PATH.value
+import importlib
+import sys

-  read_config = tfds.ReadConfig(
-      shuffle_seed=0,
-  )
-  train_ds_builder = tfds.builder(_DATASET_NAME.value)
-  train_ds = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True)
-  train_tokenizer(
-      train_ds,
-      vocab_path=os.path.join(_ASSETS_PATH.value, _VOCAB_MODEL_NAME.value),
-      vocab_size=_VOCAB_SIZE.value,
-      max_corpus_chars=_MAX_CORPUS_CHARS.value,
-  )
+from maxtext.utils import max_logging

+OLD_MODULE_PATH = "MaxText.train_tokenizer"
+NEW_MODULE_PATH = "maxtext.trainers.tokenizer.train_tokenizer"

 if __name__ == "__main__":
-  app.run(main)
+  try:
+    logging.set_verbosity(logging.INFO)
+    _new_module = importlib.import_module(NEW_MODULE_PATH)
+    if hasattr(_new_module, "main"):
+      max_logging.warning(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.\n")
+      _new_module.main(sys.argv)
+  except ImportError as e:
+    max_logging.error(f"Shim could not find target module: '{NEW_MODULE_PATH}'\n")
+    raise e
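
Note the division of labor with the relocated module in the next file: the shim calls `_new_module.main(sys.argv)` directly, bypassing `app.run`, which is what normally parses absl flags. That is why the relocated `main` adds an explicit `flags.FLAGS(sys.argv)` call. A minimal sketch of that pattern (the flag name and greeting here are hypothetical, not from this commit):

# Sketch: a main() that may be invoked directly (as the shim invokes it)
# must parse absl flags itself; app.run(main) would otherwise do this.
import sys

from absl import flags

_NAME = flags.DEFINE_string("name", "world", "Who to greet.")


def main(argv):
  del argv
  flags.FLAGS(sys.argv)  # parse flags explicitly, since app.run is bypassed
  print(f"Hello, {_NAME.value}")


if __name__ == "__main__":
  main(sys.argv)  # direct call, mirroring how the shim forwards to main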
src/maxtext/trainers/tokenizer/train_tokenizer.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Train tokenizer
+Example usage: python3 -m MaxText.train_tokenizer --dataset_path=gs://maxtext-dataset --dataset_name=c4/en:3.0.1
+"""
+
+import os
+import sys
+import tempfile
+import time
+
+from absl import app
+from absl import flags
+from absl import logging
+
+from sentencepiece import SentencePieceTrainer
+
+import jax
+
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+from MaxText.globals import MAXTEXT_ASSETS_ROOT
+
+_DATASET_PATH = flags.DEFINE_string("dataset_path", None, "Path to the dataset", required=True)
+_DATASET_NAME = flags.DEFINE_string("dataset_name", None, "Name of the dataset", required=True)
+_VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size")
+_MAX_CORPUS_CHARS = flags.DEFINE_integer("max_corpus_chars", 10_000_000, "Max corpus chars")
+_ASSETS_PATH = flags.DEFINE_string("assets_path", MAXTEXT_ASSETS_ROOT, "Path to the assets directory")
+_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Name of the vocab model file")
+
+
+def _dump_chars_to_textfile(dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]:
+  """Write part of a TFDS sentence dataset to lines in a text file.
+  Args:
+    dataset: tf.dataset containing string-data.
+    maxchars: int: approximate number of characters to save from dataset.
+    data_keys: tuple[str]: what keys in dataset to dump from.
+  Returns:
+    name of temp file with dataset bytes, exact number of characters dumped.
+  """
+  char_count = 0
+  ds_iter = dataset.as_numpy_iterator()
+  temp_dir = tempfile.gettempdir()
+  with tempfile.NamedTemporaryFile(delete=False, prefix=os.path.join(temp_dir, "ds_chars")) as outfp:
+    while char_count < maxchars:
+      example = next(ds_iter)
+      for k in data_keys:
+        line = example[k] + b"\n"
+        char_count += len(line)
+        outfp.write(line)
+  return outfp.name, char_count
+
+
+def _train_sentencepiece(
+    dataset: tf.data.Dataset,
+    *,
+    vocab_size: int,
+    maxchars: int = int(1e7),
+    model_path: str,
+    model_type: str = "unigram",
+    character_coverage: float = 1.0,
+    data_keys=("text",),
+):
+  """Train SentencePiece tokenizer from subset of tf dataset.
+  Args:
+    dataset: tf.dataset
+    vocab_size: int: size of vocab tokens to train.
+    maxchars: int: number of characters to use for sentencepiece training.
+    model_path: str: path of model file to save vocab model to.
+    model_type: str: type of sentencepiece vocab to train.
+    character_coverage: amount of characters covered by the model, good defaults
+      are 0.9995 for languages with rich character set like Japanese or Chinese
+      and 1.0 for other languages with small character set.
+    data_keys: tuple[str]: keys of dataset to use for training.
+  Returns:
+    path to the trained sentencepiece vocabulary model.
+  """
+  if model_path.startswith("gs://"):
+    abs_model_path = model_path
+  else:
+    abs_model_path = os.path.abspath(os.path.expanduser(model_path))
+  fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars, data_keys=data_keys)
+  temp_dir = tempfile.gettempdir()
+  with tempfile.NamedTemporaryFile(delete=False, prefix=os.path.join(temp_dir, "sp_tmp")) as model_fp:
+    pass  # we just want a prefix'd tmp-filename
+  argstr = " ".join(
+      [
+          f"--input={fname}",
+          f"--vocab_size={vocab_size}",
+          f"--character_coverage={character_coverage}",
+          f"--model_prefix={model_fp.name}",
+          f"--model_type={model_type}",
+      ]
+  )
+  SentencePieceTrainer.Train(argstr)
+  if jax.process_index() == 0:
+    # Use an intermediate filename that is renamed to the target name to address
+    # create and fill delays.
+    copy_rename_path = abs_model_path + ".rntmp"
+    tf.io.gfile.makedirs(os.path.dirname(abs_model_path))
+    tf.io.gfile.copy(model_fp.name + ".model", copy_rename_path, overwrite=True)
+    tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True)
+    logging.info("copied %s to %s", model_fp.name + ".model", abs_model_path)
+  else:
+    while not tf.io.gfile.exists(abs_model_path):
+      time.sleep(1)
+    time.sleep(1)
+  return abs_model_path
+
+
+def train_tokenizer(
+    dataset: tf.data.Dataset,
+    *,
+    vocab_path: str,
+    vocab_size: int,
+    max_corpus_chars: int,
+    data_keys: tuple[str] = ("text",),
+):
+  """tokenizer training function"""
+  logging.info("SentencePiece vocab not found, building one from data.")
+  vocab_path = _train_sentencepiece(
+      dataset,
+      vocab_size=vocab_size,
+      maxchars=max_corpus_chars,
+      model_path=vocab_path,
+      data_keys=data_keys,
+  )
+  logging.info("Model saved at %s", vocab_path)
+
+
+def main(argv):
+  del argv
+  flags.FLAGS(sys.argv)
+  os.environ["TFDS_DATA_DIR"] = _DATASET_PATH.value
+
+  read_config = tfds.ReadConfig(
+      shuffle_seed=0,
+  )
+  train_ds_builder = tfds.builder(_DATASET_NAME.value)
+  train_ds = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True)
+  train_tokenizer(
+      train_ds,
+      vocab_path=os.path.join(_ASSETS_PATH.value, _VOCAB_MODEL_NAME.value),
+      vocab_size=_VOCAB_SIZE.value,
+      max_corpus_chars=_MAX_CORPUS_CHARS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
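
At its core, `_train_sentencepiece` dumps a text corpus to a file and hands a single argument string to `SentencePieceTrainer.Train`, which writes `<model_prefix>.model` and `<model_prefix>.vocab`. A self-contained toy run of that call (all paths, corpus text, and sizes are illustrative; very small corpora may need a lower `--vocab_size`):

# Toy illustration of the SentencePieceTrainer.Train call assembled above;
# every value here is made up for the example.
import tempfile

from sentencepiece import SentencePieceProcessor, SentencePieceTrainer

# Dump a tiny corpus to a text file, one sentence per line.
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as corpus:
  corpus.write("the quick brown fox jumps over the lazy dog\n" * 100)

# Train writes /tmp/toy_sp.model and /tmp/toy_sp.vocab.
SentencePieceTrainer.Train(
    f"--input={corpus.name} --model_prefix=/tmp/toy_sp "
    "--vocab_size=32 --model_type=unigram --character_coverage=1.0"
)

# Load the trained model and tokenize a sample sentence.
sp = SentencePieceProcessor(model_file="/tmp/toy_sp.model")
print(sp.encode("the quick brown fox", out_type=str))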

tests/unit/tokenizer_test.py

Lines changed: 1 addition & 1 deletion
@@ -15,9 +15,9 @@
 """Tests for tokenizer"""

 import numpy as np
-from MaxText import train_tokenizer
 from MaxText.globals import MAXTEXT_ASSETS_ROOT
 from maxtext.input_pipeline import input_pipeline_utils
+from maxtext.trainers.tokenizer import train_tokenizer

 import unittest
 import pytest

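The one-line test change follows from the shim design: `MaxText.train_tokenizer` remains importable, but it only forwards `main` at execution time and no longer defines the training API, so imports must target the new package. A hedged sanity check (not part of the commit; assumes both packages are on the import path):

# The shim module no longer carries the training API; the relocated module does.
import importlib

shim = importlib.import_module("MaxText.train_tokenizer")
new = importlib.import_module("maxtext.trainers.tokenizer.train_tokenizer")

assert not hasattr(shim, "train_tokenizer")  # moved out of the shim
assert hasattr(new, "train_tokenizer")  # lives here now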