
Commit 060fcd4

Merge pull request #3410 from AI-Hypercomputer:aireen/train_tokenizer
PiperOrigin-RevId: 885251295
2 parents c6b84c1 + 6e51359 commit 060fcd4

4 files changed

Lines changed: 132 additions & 76 deletions


.github/workflows/run_tests_against_package.yml

Lines changed: 0 additions & 2 deletions
@@ -131,12 +131,10 @@ jobs:
           else
             SPLIT_ARGS=""
           fi
-          # TODO: Fix the skipped tests and remove the deselect flags
           .venv/bin/python3 -m pytest ${INPUTS_PYTEST_ADDOPTS} \
             -v \
             -m "${FINAL_PYTEST_MARKER}" \
             --durations=0 \
-            --deselect "tests/unit/tokenizer_test.py::TokenizerTest::test_detokenize" \
             --cov=MaxText \
             --cov=maxtext \
             --cov-report=xml \
237 Bytes
Binary file not shown.

src/maxtext/trainers/tokenizer/train_tokenizer.py

Lines changed: 123 additions & 53 deletions
@@ -13,13 +13,25 @@
 # limitations under the License.
 
 """ Train tokenizer
-Example usage: python3 -m MaxText.train_tokenizer --dataset_path=gs://maxtext-dataset --dataset_name=c4/en:3.0.1
+Example usage (parquet):
+  python3 -m MaxText.train_tokenizer \
+    --grain_train_files=gs://my-bucket/data/*.parquet \
+    --grain_file_type=parquet
+
+Example usage (arrayrecord):
+  python3 -m MaxText.train_tokenizer \
+    --grain_train_files=gs://my-bucket/data/*.arrayrecord \
+    --grain_file_type=arrayrecord \
+    --data_column=text
 """
 
+import glob
 import os
-import sys
+import shutil
 import tempfile
 import time
+from collections.abc import Iterator
+from pathlib import Path
 
 from absl import app
 from absl import flags
@@ -28,44 +40,101 @@
 from sentencepiece import SentencePieceTrainer
 
 import jax
-
-import tensorflow as tf
-import tensorflow_datasets as tfds
+import grain.python as grain
 
 from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
+from maxtext.utils import gcs_utils
+
 
-_DATASET_PATH = flags.DEFINE_string("dataset_path", None, "Path to the dataset", required=True)
-_DATASET_NAME = flags.DEFINE_string("dataset_name", None, "Name to the dataset", required=True)
+_GRAIN_TRAIN_FILES = flags.DEFINE_string(
+    "grain_train_files", None, "File pattern for training data (local or gs://)", required=True
+)
+_GRAIN_FILE_TYPE = flags.DEFINE_string(
+    "grain_file_type", "parquet", "Type of data files. Supported: 'parquet', 'arrayrecord'."
+)
+_DATA_COLUMN = flags.DEFINE_string("data_column", "text", "Column name to extract text from (used for arrayrecord).")
 _VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size")
 _MAX_CORPUS_CHARS = flags.DEFINE_integer("max_corpus_chars", 10_000_000, "Max corpus chars")
-_ASSETS_PATH = flags.DEFINE_string("assets_path", MAXTEXT_ASSETS_ROOT, "Name to the dataset")
-_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Name to the dataset")
+_ASSETS_PATH = flags.DEFINE_string("assets_path", MAXTEXT_ASSETS_ROOT, "Path to assets directory")
+_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Output tokenizer model name")
+
+
+def build_grain_iterator(data_file_pattern: str, data_file_type: str, data_keys: tuple[str, ...] = ("text",)) -> Iterator:
+  """Build a grain iterator from a file pattern for tokenizer training.
+
+  Args:
+    data_file_pattern: Glob pattern for data files (local path or gs://).
+    data_file_type: One of 'arrayrecord' or 'parquet'.
+    data_keys: Column names to extract from each example (used for arrayrecord).
+
+  Returns:
+    A Python iterator yielding examples as dicts.
+  """
+  if data_file_pattern.startswith("gs://"):
+    data_files = gcs_utils.gcs_glob_pattern(data_file_pattern)
+  else:
+    data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
+  if not data_files:
+    raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")
+  logging.info("Found %d files for tokenizer training.", len(data_files))
+
+  if data_file_type == "parquet":
+    dataset = grain.MapDataset.source(data_files)
+    dataset = dataset.map(grain.experimental.ParquetIterDataset)
+    dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=len(data_files))
+    return iter(dataset)
+  elif data_file_type == "arrayrecord":
+    from maxtext.input_pipeline.protos import example_pb2  # pylint: disable=import-outside-toplevel
+
+    source = grain.ArrayRecordDataSource(data_files)
+    dataset = grain.MapDataset.source(source)
+
+    def _parse_example(raw_bytes):
+      example = example_pb2.Example()
+      example.ParseFromString(raw_bytes)
+      features = example.features.feature
+      parsed = {}
+      for col in data_keys:
+        if col in features:
+          parsed[col] = features[col].bytes_list.value[0]
+      return parsed
+
+    dataset = dataset.map(_parse_example)
+    return iter(dataset)
+  else:
+    raise ValueError(f"Unsupported grain_file_type: {data_file_type!r}. Use 'parquet' or 'arrayrecord'.")
+
 
+def _dump_chars_to_textfile(dataset_iter: Iterator, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]:
+  """Write part of a grain dataset to lines in a text file.
 
-def _dump_chars_to_textfile(dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]:
-  """Write part of a TFDS sentence dataset to lines in a text file.
   Args:
-    dataset: tf.dataset containing string-data.
-    maxchars: int: approximate number of characters to save from dataset.
-    data_keys: tuple[str]: what keys in dataset to dump from.
+    dataset_iter: Iterator yielding examples as dicts.
+    maxchars: Approximate number of characters to save from dataset.
+    data_keys: Keys in each example to dump.
+
   Returns:
-    name of temp file with dataset bytes, exact number of characters dumped.
+    Name of temp file with dataset bytes, exact number of characters dumped.
   """
   char_count = 0
-  ds_iter = dataset.as_numpy_iterator()
   temp_dir = tempfile.gettempdir()
-  with tempfile.NamedTemporaryFile(delete=False, prefix=os.path.join(temp_dir, "ds_chars")) as outfp:
+  with tempfile.NamedTemporaryFile(
+      delete=False, prefix=os.path.join(temp_dir, "ds_chars"), mode="w", encoding="utf-8"
+  ) as outfp:
     while char_count < maxchars:
-      example = next(ds_iter)
+      example = next(dataset_iter)
      for k in data_keys:
-        line = example[k] + b"\n"
+        val = example[k]
+        if isinstance(val, bytes):
+          val = val.decode("utf-8")
+        line = val + "\n"
         char_count += len(line)
        outfp.write(line)
   return outfp.name, char_count
 
 
 def _train_sentencepiece(
-    dataset: tf.data.Dataset,
+    dataset_iter: Iterator,
     *,
     vocab_size: int,
     maxchars: int = int(1e7),
@@ -74,25 +143,25 @@ def _train_sentencepiece(
     character_coverage: float = 1.0,
     data_keys=("text",),
 ):
-  """Train SentencePiece tokenizer from subset of tf dataset.
+  """Train SentencePiece tokenizer from subset of a grain dataset.
+
   Args:
-    dataset: tf.dataset
-    vocab_size: int: size of vocab tokens to train.
-    maxchars: int: number of characters to use for sentencepiece training.
-    model_path: str: path of model file to save vocab model to.
-    model_type: str: type of sentencepiece vocab to train.
-    character_coverage: amount of characters covered by the model, good defaults
-      are 0.9995 for languages with rich character set like Japanese or Chinese
-      and 1.0 for other languages with small character set.
-    data_keys: tuple[str]: keys of dataset to use for training.
+    dataset_iter: Iterator yielding examples as dicts.
+    vocab_size: Size of vocab tokens to train.
+    maxchars: Number of characters to use for sentencepiece training.
+    model_path: Path to save vocab model to (local or gs://).
+    model_type: Type of sentencepiece vocab to train.
+    character_coverage: Amount of characters covered by the model.
+    data_keys: Keys of dataset to use for training.
+
   Returns:
-    path to the trained sentencepiece vocabulary model.
+    Path to the trained sentencepiece vocabulary model.
   """
   if model_path.startswith("gs://"):
     abs_model_path = model_path
   else:
     abs_model_path = os.path.abspath(os.path.expanduser(model_path))
-  fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars, data_keys=data_keys)
+  fname, _ = _dump_chars_to_textfile(dataset_iter, maxchars=maxchars, data_keys=data_keys)
   temp_dir = tempfile.gettempdir()
   with tempfile.NamedTemporaryFile(delete=False, prefix=os.path.join(temp_dir, "sp_tmp")) as model_fp:
     pass  # we just want a prefix'd tmp-filename
@@ -107,32 +176,38 @@ def _train_sentencepiece(
   )
   SentencePieceTrainer.Train(argstr)
   if jax.process_index() == 0:
-    # Use an intermediate filename that is renamed to the target name to address
-    # create and fill delays.
-    copy_rename_path = abs_model_path + ".rntmp"
-    tf.io.gfile.makedirs(os.path.dirname(abs_model_path))
-    tf.io.gfile.copy(model_fp.name + ".model", copy_rename_path, overwrite=True)
-    tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True)
-    logging.info("copied %s to %s", model_fp.name + ".model", abs_model_path)
+    if abs_model_path.startswith("gs://"):
+      gcs_utils.upload_blob(abs_model_path, model_fp.name + ".model")
+      logging.info("Uploaded %s to %s", model_fp.name + ".model", abs_model_path)
+    else:
+      parent = os.path.dirname(abs_model_path)
+      if parent:
+        os.makedirs(parent, exist_ok=True)
+      shutil.copy(model_fp.name + ".model", abs_model_path)
+      logging.info("Copied %s to %s", model_fp.name + ".model", abs_model_path)
   else:
-    while not tf.io.gfile.exists(abs_model_path):
-      time.sleep(1)
+    if abs_model_path.startswith("gs://"):
+      while not gcs_utils.gcs_path_exists(abs_model_path):
+        time.sleep(1)
+    else:
+      while not os.path.exists(abs_model_path):
+        time.sleep(1)
     time.sleep(1)
   return abs_model_path
 
 
 def train_tokenizer(
-    dataset: tf.data.Dataset,
+    dataset_iter: Iterator,
     *,
     vocab_path: str,
     vocab_size: int,
     max_corpus_chars: int,
     data_keys: tuple[str] = ("text",),
 ):
-  """tokenizer training function"""
+  """Tokenizer training function."""
   logging.info("SentencePiece vocab not found, building one from data.")
   vocab_path = _train_sentencepiece(
-      dataset,
+      dataset_iter,
       vocab_size=vocab_size,
       maxchars=max_corpus_chars,
       model_path=vocab_path,
@@ -143,19 +218,14 @@ def train_tokenizer(
 
 def main(argv):
   del argv
-  flags.FLAGS(sys.argv)
-  os.environ["TFDS_DATA_DIR"] = _DATASET_PATH.value
-
-  read_config = tfds.ReadConfig(
-      shuffle_seed=0,
-  )
-  train_ds_builder = tfds.builder(_DATASET_NAME.value)
-  train_ds = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True)
+  data_keys = (_DATA_COLUMN.value,)
+  dataset_iter = build_grain_iterator(_GRAIN_TRAIN_FILES.value, _GRAIN_FILE_TYPE.value, data_keys=data_keys)
   train_tokenizer(
-      train_ds,
+      dataset_iter,
       vocab_path=os.path.join(_ASSETS_PATH.value, _VOCAB_MODEL_NAME.value),
       vocab_size=_VOCAB_SIZE.value,
       max_corpus_chars=_MAX_CORPUS_CHARS.value,
+      data_keys=data_keys,
  )
161231

tests/unit/tokenizer_test.py

Lines changed: 9 additions & 21 deletions
@@ -21,18 +21,17 @@
 
 import unittest
 import pytest
-import tensorflow_datasets as tfds
 import subprocess
 import os
 
 
-class TokenizerTest(unittest.TestCase):
-  """Tests for train_tokenizer.py"""
+class TrainTokenizerTest(unittest.TestCase):
+  """Tests for train_tokenizer.py using data from Parquet files"""
 
   @classmethod
   def setUpClass(cls):
-    dataset_name = "c4/en:3.0.1"
-    dataset_path = "gs://maxtext-dataset"
+    # The test only uses ~10 MB of data; one file is enough, and more files slow it down.
+    grain_train_files = "gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet"
     cls.vocab_size = 32_768
     cls.max_corpus_chars = 10_000_000
     assets_path = "tests"
@@ -44,14 +43,9 @@ def setUpClass(cls):
         add_bos=False,
         add_eos=False,
     )
-    os.environ["TFDS_DATA_DIR"] = dataset_path
-    read_config = tfds.ReadConfig(
-        shuffle_seed=0,
-    )
-    train_ds_builder = tfds.builder(dataset_name)
-    cls.dataset = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True)
+    dataset_iter = train_tokenizer.build_grain_iterator(grain_train_files, "parquet")
     train_tokenizer.train_tokenizer(
-        cls.dataset,
+        dataset_iter,
         vocab_path=cls.tokenizer_path,
         vocab_size=cls.vocab_size,
         max_corpus_chars=cls.max_corpus_chars,
@@ -76,24 +70,18 @@ def test_detokenize(self):
 
 
 class TikTokenTest(unittest.TestCase):
-  """Tests for train_tokenizer.py"""
+  """Tests for TikToken"""
 
   @classmethod
   def setUpClass(cls):
-    dataset_name = "c4/en:3.0.1"
-    dataset_path = "gs://maxtext-dataset"
+    grain_train_files = "gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet"
     cls.source_tokenizer = input_pipeline_utils.get_tokenizer(
         os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
         "tiktoken",
         add_bos=False,
         add_eos=False,
     )
-    os.environ["TFDS_DATA_DIR"] = dataset_path
-    read_config = tfds.ReadConfig(
-        shuffle_seed=0,
-    )
-    train_ds_builder = tfds.builder(dataset_name)
-    cls.dataset = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True)
+    cls.dataset = train_tokenizer.build_grain_iterator(grain_train_files, "parquet")
 
   @pytest.mark.tpu_only
   def test_tokenize(self):
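With the --deselect flag removed from the workflow above, the renamed TrainTokenizerTest class runs as part of the regular pytest invocation. For a local spot-check it can also be targeted directly, assuming the gs://maxtext-dataset parquet shard referenced in setUpClass is readable from the environment:

python3 -m pytest tests/unit/tokenizer_test.py::TrainTokenizerTest -v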
