Commit 6e51359: Migrate train_tokenizer off TF
Parent: d561ad4
4 files changed: 133 additions & 76 deletions

.github/workflows/run_tests_against_package.yml

Lines changed: 0 additions & 2 deletions
@@ -121,12 +121,10 @@ jobs:
           else
             SPLIT_ARGS=""
           fi
-          # TODO: Fix the skipped tests and remove the deselect flags
           .venv/bin/python3 -m pytest ${INPUTS_PYTEST_ADDOPTS} \
             -v \
             -m "${FINAL_PYTEST_MARKER}" \
             --durations=0 \
-            --deselect "tests/unit/tokenizer_test.py::TokenizerTest::test_detokenize" \
             --cov=MaxText \
             --cov=maxtext \
             --cov-report=xml \
237 Bytes: Binary file not shown.

src/maxtext/trainers/tokenizer/train_tokenizer.py

Lines changed: 124 additions & 53 deletions
@@ -13,13 +13,25 @@
 # limitations under the License.

 """ Train tokenizer
-Example usage: python3 -m MaxText.train_tokenizer --dataset_path=gs://maxtext-dataset --dataset_name=c4/en:3.0.1
+Example usage (parquet):
+  python3 -m MaxText.train_tokenizer \
+    --grain_train_files=gs://my-bucket/data/*.parquet \
+    --grain_file_type=parquet
+
+Example usage (arrayrecord):
+  python3 -m MaxText.train_tokenizer \
+    --grain_train_files=gs://my-bucket/data/*.arrayrecord \
+    --grain_file_type=arrayrecord \
+    --data_column=text
 """

+import glob
 import os
-import sys
+import shutil
 import tempfile
 import time
+from collections.abc import Iterator
+from pathlib import Path

 from absl import app
 from absl import flags
@@ -28,44 +40,102 @@
 from sentencepiece import SentencePieceTrainer

 import jax
-
-import tensorflow as tf
-import tensorflow_datasets as tfds
+import grain.python as grain
+import grain.experimental

 from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
+from maxtext.utils import gcs_utils
+

-_DATASET_PATH = flags.DEFINE_string("dataset_path", None, "Path to the dataset", required=True)
-_DATASET_NAME = flags.DEFINE_string("dataset_name", None, "Name to the dataset", required=True)
+_GRAIN_TRAIN_FILES = flags.DEFINE_string(
+    "grain_train_files", None, "File pattern for training data (local or gs://)", required=True
+)
+_GRAIN_FILE_TYPE = flags.DEFINE_string(
+    "grain_file_type", "parquet", "Type of data files. Supported: 'parquet', 'arrayrecord'."
+)
+_DATA_COLUMN = flags.DEFINE_string("data_column", "text", "Column name to extract text from (used for arrayrecord).")
 _VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size")
 _MAX_CORPUS_CHARS = flags.DEFINE_integer("max_corpus_chars", 10_000_000, "Max corpus chars")
-_ASSETS_PATH = flags.DEFINE_string("assets_path", MAXTEXT_ASSETS_ROOT, "Name to the dataset")
-_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Name to the dataset")
+_ASSETS_PATH = flags.DEFINE_string("assets_path", MAXTEXT_ASSETS_ROOT, "Path to assets directory")
+_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Output tokenizer model name")
+
+
+def build_grain_iterator(data_file_pattern: str, data_file_type: str, data_keys: tuple[str, ...] = ("text",)) -> Iterator:
+  """Build a grain iterator from a file pattern for tokenizer training.
+
+  Args:
+    data_file_pattern: Glob pattern for data files (local path or gs://).
+    data_file_type: One of 'arrayrecord' or 'parquet'.
+    data_keys: Column names to extract from each example (used for arrayrecord).
+
+  Returns:
+    A Python iterator yielding examples as dicts.
+  """
+  if data_file_pattern.startswith("gs://"):
+    data_files = gcs_utils.gcs_glob_pattern(data_file_pattern)
+  else:
+    data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
+  if not data_files:
+    raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")
+  logging.info("Found %d files for tokenizer training.", len(data_files))
+
+  if data_file_type == "parquet":
+    dataset = grain.MapDataset.source(data_files)
+    dataset = dataset.map(grain.experimental.ParquetIterDataset)
+    dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=len(data_files))
+    return iter(dataset)
+  elif data_file_type == "arrayrecord":
+    from maxtext.input_pipeline.protos import example_pb2  # pylint: disable=import-outside-toplevel
+
+    source = grain.ArrayRecordDataSource(data_files)
+    dataset = grain.MapDataset.source(source)
+
+    def _parse_example(raw_bytes):
+      example = example_pb2.Example()
+      example.ParseFromString(raw_bytes)
+      features = example.features.feature
+      parsed = {}
+      for col in data_keys:
+        if col in features:
+          parsed[col] = features[col].bytes_list.value[0]
+      return parsed
+
+    dataset = dataset.map(_parse_example)
+    return iter(dataset)
+  else:
+    raise ValueError(f"Unsupported grain_file_type: {data_file_type!r}. Use 'parquet' or 'arrayrecord'.")
+

+def _dump_chars_to_textfile(dataset_iter: Iterator, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]:
+  """Write part of a grain dataset to lines in a text file.

-def _dump_chars_to_textfile(dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]:
-  """Write part of a TFDS sentence dataset to lines in a text file.
   Args:
-    dataset: tf.dataset containing string-data.
-    maxchars: int: approximate number of characters to save from dataset.
-    data_keys: tuple[str]: what keys in dataset to dump from.
+    dataset_iter: Iterator yielding examples as dicts.
+    maxchars: Approximate number of characters to save from dataset.
+    data_keys: Keys in each example to dump.
+
   Returns:
-    name of temp file with dataset bytes, exact number of characters dumped.
+    Name of temp file with dataset bytes, exact number of characters dumped.
   """
   char_count = 0
-  ds_iter = dataset.as_numpy_iterator()
   temp_dir = tempfile.gettempdir()
-  with tempfile.NamedTemporaryFile(delete=False, prefix=os.path.join(temp_dir, "ds_chars")) as outfp:
+  with tempfile.NamedTemporaryFile(
+      delete=False, prefix=os.path.join(temp_dir, "ds_chars"), mode="w", encoding="utf-8"
+  ) as outfp:
     while char_count < maxchars:
-      example = next(ds_iter)
+      example = next(dataset_iter)
       for k in data_keys:
-        line = example[k] + b"\n"
+        val = example[k]
+        if isinstance(val, bytes):
+          val = val.decode("utf-8")
+        line = val + "\n"
         char_count += len(line)
         outfp.write(line)
   return outfp.name, char_count


 def _train_sentencepiece(
-    dataset: tf.data.Dataset,
+    dataset_iter: Iterator,
     *,
     vocab_size: int,
     maxchars: int = int(1e7),
@@ -74,25 +144,25 @@ def _train_sentencepiece(
     character_coverage: float = 1.0,
     data_keys=("text",),
 ):
-  """Train SentencePiece tokenizer from subset of tf dataset.
+  """Train SentencePiece tokenizer from subset of a grain dataset.
+
   Args:
-    dataset: tf.dataset
-    vocab_size: int: size of vocab tokens to train.
-    maxchars: int: number of characters to use for sentencepiece training.
-    model_path: str: path of model file to save vocab model to.
-    model_type: str: type of sentencepiece vocab to train.
-    character_coverage: amount of characters covered by the model, good defaults
-      are 0.9995 for languages with rich character set like Japanese or Chinese
-      and 1.0 for other languages with small character set.
-    data_keys: tuple[str]: keys of dataset to use for training.
+    dataset_iter: Iterator yielding examples as dicts.
+    vocab_size: Size of vocab tokens to train.
+    maxchars: Number of characters to use for sentencepiece training.
+    model_path: Path to save vocab model to (local or gs://).
+    model_type: Type of sentencepiece vocab to train.
+    character_coverage: Amount of characters covered by the model.
+    data_keys: Keys of dataset to use for training.
+
   Returns:
-    path to the trained sentencepiece vocabulary model.
+    Path to the trained sentencepiece vocabulary model.
   """
   if model_path.startswith("gs://"):
     abs_model_path = model_path
   else:
     abs_model_path = os.path.abspath(os.path.expanduser(model_path))
-  fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars, data_keys=data_keys)
+  fname, _ = _dump_chars_to_textfile(dataset_iter, maxchars=maxchars, data_keys=data_keys)
   temp_dir = tempfile.gettempdir()
   with tempfile.NamedTemporaryFile(delete=False, prefix=os.path.join(temp_dir, "sp_tmp")) as model_fp:
     pass  # we just want a prefix'd tmp-filename
@@ -107,32 +177,38 @@ def _train_sentencepiece(
   )
   SentencePieceTrainer.Train(argstr)
   if jax.process_index() == 0:
-    # Use an intermediate filename that is renamed to the target name to address
-    # create and fill delays.
-    copy_rename_path = abs_model_path + ".rntmp"
-    tf.io.gfile.makedirs(os.path.dirname(abs_model_path))
-    tf.io.gfile.copy(model_fp.name + ".model", copy_rename_path, overwrite=True)
-    tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True)
-    logging.info("copied %s to %s", model_fp.name + ".model", abs_model_path)
+    if abs_model_path.startswith("gs://"):
+      gcs_utils.upload_blob(abs_model_path, model_fp.name + ".model")
+      logging.info("Uploaded %s to %s", model_fp.name + ".model", abs_model_path)
+    else:
+      parent = os.path.dirname(abs_model_path)
+      if parent:
+        os.makedirs(parent, exist_ok=True)
+      shutil.copy(model_fp.name + ".model", abs_model_path)
+      logging.info("Copied %s to %s", model_fp.name + ".model", abs_model_path)
   else:
-    while not tf.io.gfile.exists(abs_model_path):
-      time.sleep(1)
+    if abs_model_path.startswith("gs://"):
+      while not gcs_utils.gcs_path_exists(abs_model_path):
+        time.sleep(1)
+    else:
+      while not os.path.exists(abs_model_path):
+        time.sleep(1)
     time.sleep(1)
   return abs_model_path


 def train_tokenizer(
-    dataset: tf.data.Dataset,
+    dataset_iter: Iterator,
     *,
     vocab_path: str,
     vocab_size: int,
     max_corpus_chars: int,
     data_keys: tuple[str] = ("text",),
 ):
-  """tokenizer training function"""
+  """Tokenizer training function."""
   logging.info("SentencePiece vocab not found, building one from data.")
   vocab_path = _train_sentencepiece(
-      dataset,
+      dataset_iter,
       vocab_size=vocab_size,
       maxchars=max_corpus_chars,
       model_path=vocab_path,
@@ -143,19 +219,14 @@ def train_tokenizer(

 def main(argv):
   del argv
-  flags.FLAGS(sys.argv)
-  os.environ["TFDS_DATA_DIR"] = _DATASET_PATH.value
-
-  read_config = tfds.ReadConfig(
-      shuffle_seed=0,
-  )
-  train_ds_builder = tfds.builder(_DATASET_NAME.value)
-  train_ds = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True)
+  data_keys = (_DATA_COLUMN.value,)
+  dataset_iter = build_grain_iterator(_GRAIN_TRAIN_FILES.value, _GRAIN_FILE_TYPE.value, data_keys=data_keys)
   train_tokenizer(
-      train_ds,
+      dataset_iter,
       vocab_path=os.path.join(_ASSETS_PATH.value, _VOCAB_MODEL_NAME.value),
       vocab_size=_VOCAB_SIZE.value,
       max_corpus_chars=_MAX_CORPUS_CHARS.value,
+      data_keys=data_keys,
   )

tests/unit/tokenizer_test.py

Lines changed: 9 additions & 21 deletions
@@ -21,18 +21,17 @@

 import unittest
 import pytest
-import tensorflow_datasets as tfds
 import subprocess
 import os


-class TokenizerTest(unittest.TestCase):
-  """Tests for train_tokenizer.py"""
+class TrainTokenizerTest(unittest.TestCase):
+  """Tests for train_tokenizer.py using data from Parquet files"""

   @classmethod
   def setUpClass(cls):
-    dataset_name = "c4/en:3.0.1"
-    dataset_path = "gs://maxtext-dataset"
+    # the test only use ~10Mb of data, one file is enough, more files cause slow down
+    grain_train_files = "gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet"
     cls.vocab_size = 32_768
     cls.max_corpus_chars = 10_000_000
     assets_path = "tests"
@@ -44,14 +43,9 @@ def setUpClass(cls):
         add_bos=False,
         add_eos=False,
     )
-    os.environ["TFDS_DATA_DIR"] = dataset_path
-    read_config = tfds.ReadConfig(
-        shuffle_seed=0,
-    )
-    train_ds_builder = tfds.builder(dataset_name)
-    cls.dataset = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True)
+    dataset_iter = train_tokenizer.build_grain_iterator(grain_train_files, "parquet")
     train_tokenizer.train_tokenizer(
-        cls.dataset,
+        dataset_iter,
         vocab_path=cls.tokenizer_path,
         vocab_size=cls.vocab_size,
         max_corpus_chars=cls.max_corpus_chars,
@@ -76,24 +70,18 @@ def test_detokenize(self):


 class TikTokenTest(unittest.TestCase):
-  """Tests for train_tokenizer.py"""
+  """Tests for TikToken"""

   @classmethod
   def setUpClass(cls):
-    dataset_name = "c4/en:3.0.1"
-    dataset_path = "gs://maxtext-dataset"
+    grain_train_files = "gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet"
     cls.source_tokenizer = input_pipeline_utils.get_tokenizer(
         os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"),
         "tiktoken",
         add_bos=False,
         add_eos=False,
     )
-    os.environ["TFDS_DATA_DIR"] = dataset_path
-    read_config = tfds.ReadConfig(
-        shuffle_seed=0,
-    )
-    train_ds_builder = tfds.builder(dataset_name)
-    cls.dataset = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True)
+    cls.dataset = train_tokenizer.build_grain_iterator(grain_train_files, "parquet")

   @pytest.mark.tpu_only
   def test_tokenize(self):
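For reference, the updated tests can be exercised locally with a standard pytest invocation (this command is ordinary pytest usage, not taken from the repository docs), for example:

python3 -m pytest tests/unit/tokenizer_test.py -v -k TrainTokenizerTest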
