|
1 | | -# Copyright 2023–2025 Google LLC |
| 1 | +# Copyright 2023–2026 Google LLC |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -""" Train tokenizer |
16 | | -Example usage: python3 -m MaxText.train_tokenizer --dataset_path=gs://maxtext-dataset --dataset_name=c4/en:3.0.1 |
17 | | -""" |
| 15 | +"""Shim for `train_tokenizer` in `src/maxtext/trainers/tokenizer`.""" |
18 | 16 |
|
19 | | -import os |
20 | | -import tempfile |
21 | | -import time |
22 | | - |
23 | | -from absl import app |
24 | | -from absl import flags |
25 | 17 | from absl import logging |
26 | 18 |
|
27 | | -from sentencepiece import SentencePieceTrainer |
28 | | - |
29 | | -import jax |
30 | | - |
31 | | -import tensorflow as tf |
32 | | -import tensorflow_datasets as tfds |
33 | | - |
34 | | -from MaxText.globals import MAXTEXT_ASSETS_ROOT |
35 | | - |
36 | | -_DATASET_PATH = flags.DEFINE_string("dataset_path", None, "Path to the dataset", required=True) |
37 | | -_DATASET_NAME = flags.DEFINE_string("dataset_name", None, "Name to the dataset", required=True) |
38 | | -_VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size") |
39 | | -_MAX_CORPUS_CHARS = flags.DEFINE_integer("max_corpus_chars", 10_000_000, "Max corpus chars") |
40 | | -_ASSETS_PATH = flags.DEFINE_string("assets_path", MAXTEXT_ASSETS_ROOT, "Name to the dataset") |
41 | | -_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Name to the dataset") |
42 | | - |
43 | | - |
44 | | -def _dump_chars_to_textfile(dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]: |
45 | | - """Write part of a TFDS sentence dataset to lines in a text file. |
46 | | - Args: |
47 | | - dataset: tf.dataset containing string-data. |
48 | | - maxchars: int: approximate number of characters to save from dataset. |
49 | | - data_keys: tuple[str]: what keys in dataset to dump from. |
50 | | - Returns: |
51 | | - name of temp file with dataset bytes, exact number of characters dumped. |
52 | | - """ |
53 | | - char_count = 0 |
54 | | - ds_iter = dataset.as_numpy_iterator() |
55 | | - temp_dir = tempfile.gettempdir() |
56 | | - with tempfile.NamedTemporaryFile(delete=False, prefix=os.path.join(temp_dir, "ds_chars")) as outfp: |
57 | | - while char_count < maxchars: |
58 | | - example = next(ds_iter) |
59 | | - for k in data_keys: |
60 | | - line = example[k] + b"\n" |
61 | | - char_count += len(line) |
62 | | - outfp.write(line) |
63 | | - return outfp.name, char_count |
64 | | - |
65 | | - |
66 | | -def _train_sentencepiece( |
67 | | - dataset: tf.data.Dataset, |
68 | | - *, |
69 | | - vocab_size: int, |
70 | | - maxchars: int = int(1e7), |
71 | | - model_path: str, |
72 | | - model_type: str = "unigram", |
73 | | - character_coverage: float = 1.0, |
74 | | - data_keys=("text",), |
75 | | -): |
76 | | - """Train SentencePiece tokenizer from subset of tf dataset. |
77 | | - Args: |
78 | | - dataset: tf.dataset |
79 | | - vocab_size: int: size of vocab tokens to train. |
80 | | - maxchars: int: number of characters to use for sentencepiece training. |
81 | | - model_path: str: path of model file to save vocab model to. |
82 | | - model_type: str: type of sentencepiece vocab to train. |
83 | | - character_coverage: amount of characters covered by the model, good defaults |
84 | | - are 0.9995 for languages with rich character set like Japanese or Chinese |
85 | | - and 1.0 for other languages with small character set. |
86 | | - data_keys: tuple[str]: keys of dataset to use for training. |
87 | | - Returns: |
88 | | - path to the trained sentencepiece vocabulary model. |
89 | | - """ |
90 | | - if model_path.startswith("gs://"): |
91 | | - abs_model_path = model_path |
92 | | - else: |
93 | | - abs_model_path = os.path.abspath(os.path.expanduser(model_path)) |
94 | | - fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars, data_keys=data_keys) |
95 | | - temp_dir = tempfile.gettempdir() |
96 | | - with tempfile.NamedTemporaryFile(delete=False, prefix=os.path.join(temp_dir, "sp_tmp")) as model_fp: |
97 | | - pass # we just want a prefix'd tmp-filename |
98 | | - argstr = " ".join( |
99 | | - [ |
100 | | - f"--input={fname}", |
101 | | - f"--vocab_size={vocab_size}", |
102 | | - f"--character_coverage={character_coverage}", |
103 | | - f"--model_prefix={model_fp.name}", |
104 | | - f"--model_type={model_type}", |
105 | | - ] |
106 | | - ) |
107 | | - SentencePieceTrainer.Train(argstr) |
108 | | - if jax.process_index() == 0: |
109 | | - # Use an intermediate filename that is renamed to the target name to address |
110 | | - # create and fill delays. |
111 | | - copy_rename_path = abs_model_path + ".rntmp" |
112 | | - tf.io.gfile.makedirs(os.path.dirname(abs_model_path)) |
113 | | - tf.io.gfile.copy(model_fp.name + ".model", copy_rename_path, overwrite=True) |
114 | | - tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True) |
115 | | - logging.info("copied %s to %s", model_fp.name + ".model", abs_model_path) |
116 | | - else: |
117 | | - while not tf.io.gfile.exists(abs_model_path): |
118 | | - time.sleep(1) |
119 | | - time.sleep(1) |
120 | | - return abs_model_path |
121 | | - |
122 | | - |
123 | | -def train_tokenizer( |
124 | | - dataset: tf.data.Dataset, |
125 | | - *, |
126 | | - vocab_path: str, |
127 | | - vocab_size: int, |
128 | | - max_corpus_chars: int, |
129 | | - data_keys: tuple[str] = ("text",), |
130 | | -): |
131 | | - """tokenizer training function""" |
132 | | - logging.info("SentencePiece vocab not found, building one from data.") |
133 | | - vocab_path = _train_sentencepiece( |
134 | | - dataset, |
135 | | - vocab_size=vocab_size, |
136 | | - maxchars=max_corpus_chars, |
137 | | - model_path=vocab_path, |
138 | | - data_keys=data_keys, |
139 | | - ) |
140 | | - logging.info("Model saved at %s", vocab_path) |
141 | | - |
142 | | - |
143 | | -def main(argv): |
144 | | - del argv |
145 | | - os.environ["TFDS_DATA_DIR"] = _DATASET_PATH.value |
| 19 | +import importlib |
| 20 | +import sys |
146 | 21 |
|
147 | | - read_config = tfds.ReadConfig( |
148 | | - shuffle_seed=0, |
149 | | - ) |
150 | | - train_ds_builder = tfds.builder(_DATASET_NAME.value) |
151 | | - train_ds = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True) |
152 | | - train_tokenizer( |
153 | | - train_ds, |
154 | | - vocab_path=os.path.join(_ASSETS_PATH.value, _VOCAB_MODEL_NAME.value), |
155 | | - vocab_size=_VOCAB_SIZE.value, |
156 | | - max_corpus_chars=_MAX_CORPUS_CHARS.value, |
157 | | - ) |
| 22 | +from maxtext.utils import max_logging |
158 | 23 |
|
| 24 | +OLD_MODULE_PATH = "MaxText.train_tokenizer" |
| 25 | +NEW_MODULE_PATH = "maxtext.trainers.tokenizer.train_tokenizer" |
159 | 26 |
|
160 | 27 | if __name__ == "__main__": |
161 | | - app.run(main) |
| 28 | + try: |
| 29 | + logging.set_verbosity(logging.INFO) |
| 30 | + _new_module = importlib.import_module(NEW_MODULE_PATH) |
| 31 | + if hasattr(_new_module, "main"): |
| 32 | + max_logging.warning(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.\n") |
| 33 | + _new_module.main(sys.argv) |
| 34 | + except ImportError as e: |
| 35 | + max_logging.error(f"Shim could not find target module: '{NEW_MODULE_PATH}'\n") |
| 36 | + raise e |
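Usage note: the invocation below is a sketch based on the example in the removed docstring; the old command keeps working through the shim above (now with a deprecation warning), and the new module path comes from `NEW_MODULE_PATH`. It assumes the relocated `maxtext.trainers.tokenizer.train_tokenizer` still accepts the same `--dataset_path` and `--dataset_name` flags, which this diff does not show.

```sh
# Old entry point: still runs via the shim, which forwards sys.argv and logs a deprecation warning.
python3 -m MaxText.train_tokenizer --dataset_path=gs://maxtext-dataset --dataset_name=c4/en:3.0.1

# New entry point targeted by the shim (assumed to take the same flags as before).
python3 -m maxtext.trainers.tokenizer.train_tokenizer --dataset_path=gs://maxtext-dataset --dataset_name=c4/en:3.0.1
```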