1313# limitations under the License.
1414
1515""" Train tokenizer
16- Example usage: python3 -m MaxText.train_tokenizer --dataset_path=gs://maxtext-dataset --dataset_name=c4/en:3.0.1
16+ Example usage (parquet):
17+ python3 -m MaxText.train_tokenizer \
18+ --grain_train_files=gs://my-bucket/data/*.parquet \
19+ --grain_file_type=parquet
20+
21+ Example usage (arrayrecord):
22+ python3 -m MaxText.train_tokenizer \
23+ --grain_train_files=gs://my-bucket/data/*.arrayrecord \
24+ --grain_file_type=arrayrecord \
25+ --data_column=text
1726"""
1827
28+ import glob
1929import os
20- import sys
30+ import shutil
2131import tempfile
2232import time
33+ from collections .abc import Iterator
34+ from pathlib import Path
2335
2436from absl import app
2537from absl import flags
2840from sentencepiece import SentencePieceTrainer
2941
3042import jax
31-
32- import tensorflow as tf
33- import tensorflow_datasets as tfds
43+ import grain .python as grain
44+ import grain .experimental
3445
3546from maxtext .utils .globals import MAXTEXT_ASSETS_ROOT
47+ from maxtext .utils import gcs_utils
48+
3649
# Command-line flags.
# Data-source selection: a glob pattern (local or gs://) plus the on-disk format.
_GRAIN_TRAIN_FILES = flags.DEFINE_string(
    "grain_train_files", None, "File pattern for training data (local or gs://)", required=True
)
_GRAIN_FILE_TYPE = flags.DEFINE_string("grain_file_type", "parquet", "Type of data files. Supported: 'parquet', 'arrayrecord'.")
# Column holding the text to train on; only consulted for the arrayrecord path
# (parquet rows are consumed as produced by ParquetIterDataset).
_DATA_COLUMN = flags.DEFINE_string("data_column", "text", "Column name to extract text from (used for arrayrecord).")
# SentencePiece training knobs.
_VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size")
_MAX_CORPUS_CHARS = flags.DEFINE_integer("max_corpus_chars", 10_000_000, "Max corpus chars")
# Where the trained tokenizer model is written: <assets_path>/<vocab_model_name>.
_ASSETS_PATH = flags.DEFINE_string("assets_path", MAXTEXT_ASSETS_ROOT, "Path to assets directory")
_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Output tokenizer model name")
61+
62+
def build_grain_iterator(data_file_pattern: str, data_file_type: str, data_keys: tuple[str, ...] = ("text",)) -> Iterator:
  """Build a grain iterator from a file pattern for tokenizer training.

  Args:
    data_file_pattern: Glob pattern for data files (local path or gs://).
    data_file_type: One of 'arrayrecord' or 'parquet'.
    data_keys: Column names to extract from each example (used for arrayrecord).

  Returns:
    A Python iterator yielding examples as dicts.
  """
  # Resolve the pattern to a concrete file list; GCS patterns go through gcs_utils.
  is_gcs = data_file_pattern.startswith("gs://")
  if is_gcs:
    matched_files = gcs_utils.gcs_glob_pattern(data_file_pattern)
  else:
    matched_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
  if not matched_files:
    raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")
  logging.info("Found %d files for tokenizer training.", len(matched_files))

  if data_file_type == "parquet":
    # One ParquetIterDataset per file, interleaved across all files.
    file_ds = grain.MapDataset.source(matched_files).map(grain.experimental.ParquetIterDataset)
    interleaved = grain.experimental.InterleaveIterDataset(file_ds, cycle_length=len(matched_files))
    return iter(interleaved)

  if data_file_type == "arrayrecord":
    from maxtext.input_pipeline.protos import example_pb2  # pylint: disable=import-outside-toplevel

    record_ds = grain.MapDataset.source(grain.ArrayRecordDataSource(matched_files))

    def _decode_record(raw_bytes):
      # ArrayRecord stores serialized tf.train.Example protos; pull out the
      # requested columns as raw bytes (keys absent from the proto are dropped).
      proto = example_pb2.Example()
      proto.ParseFromString(raw_bytes)
      feature_map = proto.features.feature
      return {col: feature_map[col].bytes_list.value[0] for col in data_keys if col in feature_map}

    return iter(record_ds.map(_decode_record))

  raise ValueError(f"Unsupported grain_file_type: {data_file_type!r}. Use 'parquet' or 'arrayrecord'.")
107+
43108
109+ def _dump_chars_to_textfile (dataset_iter : Iterator , maxchars : int = int (1e7 ), data_keys = ("text" ,)) -> tuple [str , int ]:
110+ """Write part of a grain dataset to lines in a text file.
44111
45- def _dump_chars_to_textfile (dataset : tf .data .Dataset , maxchars : int = int (1e7 ), data_keys = ("text" ,)) -> tuple [str , int ]:
46- """Write part of a TFDS sentence dataset to lines in a text file.
47112 Args:
48- dataset: tf.dataset containing string-data.
49- maxchars: int: approximate number of characters to save from dataset.
50- data_keys: tuple[str]: what keys in dataset to dump from.
113+ dataset_iter: Iterator yielding examples as dicts.
114+ maxchars: Approximate number of characters to save from dataset.
115+ data_keys: Keys in each example to dump.
116+
51117 Returns:
52- name of temp file with dataset bytes, exact number of characters dumped.
118+ Name of temp file with dataset bytes, exact number of characters dumped.
53119 """
54120 char_count = 0
55- ds_iter = dataset .as_numpy_iterator ()
56121 temp_dir = tempfile .gettempdir ()
57- with tempfile .NamedTemporaryFile (delete = False , prefix = os .path .join (temp_dir , "ds_chars" )) as outfp :
122+ with tempfile .NamedTemporaryFile (
123+ delete = False , prefix = os .path .join (temp_dir , "ds_chars" ), mode = "w" , encoding = "utf-8"
124+ ) as outfp :
58125 while char_count < maxchars :
59- example = next (ds_iter )
126+ example = next (dataset_iter )
60127 for k in data_keys :
61- line = example [k ] + b"\n "
128+ val = example [k ]
129+ if isinstance (val , bytes ):
130+ val = val .decode ("utf-8" )
131+ line = val + "\n "
62132 char_count += len (line )
63133 outfp .write (line )
64134 return outfp .name , char_count
65135
66136
67137def _train_sentencepiece (
68- dataset : tf . data . Dataset ,
138+ dataset_iter : Iterator ,
69139 * ,
70140 vocab_size : int ,
71141 maxchars : int = int (1e7 ),
@@ -74,25 +144,25 @@ def _train_sentencepiece(
74144 character_coverage : float = 1.0 ,
75145 data_keys = ("text" ,),
76146):
77- """Train SentencePiece tokenizer from subset of tf dataset.
147+ """Train SentencePiece tokenizer from subset of a grain dataset.
148+
78149 Args:
79- dataset: tf.dataset
80- vocab_size: int: size of vocab tokens to train.
81- maxchars: int: number of characters to use for sentencepiece training.
82- model_path: str: path of model file to save vocab model to.
83- model_type: str: type of sentencepiece vocab to train.
84- character_coverage: amount of characters covered by the model, good defaults
85- are 0.9995 for languages with rich character set like Japanese or Chinese
86- and 1.0 for other languages with small character set.
87- data_keys: tuple[str]: keys of dataset to use for training.
150+ dataset_iter: Iterator yielding examples as dicts.
151+ vocab_size: Size of vocab tokens to train.
152+ maxchars: Number of characters to use for sentencepiece training.
153+ model_path: Path to save vocab model to (local or gs://).
154+ model_type: Type of sentencepiece vocab to train.
155+ character_coverage: Amount of characters covered by the model.
156+ data_keys: Keys of dataset to use for training.
157+
88158 Returns:
89- path to the trained sentencepiece vocabulary model.
159+ Path to the trained sentencepiece vocabulary model.
90160 """
91161 if model_path .startswith ("gs://" ):
92162 abs_model_path = model_path
93163 else :
94164 abs_model_path = os .path .abspath (os .path .expanduser (model_path ))
95- fname , _ = _dump_chars_to_textfile (dataset , maxchars = maxchars , data_keys = data_keys )
165+ fname , _ = _dump_chars_to_textfile (dataset_iter , maxchars = maxchars , data_keys = data_keys )
96166 temp_dir = tempfile .gettempdir ()
97167 with tempfile .NamedTemporaryFile (delete = False , prefix = os .path .join (temp_dir , "sp_tmp" )) as model_fp :
98168 pass # we just want a prefix'd tmp-filename
@@ -107,32 +177,38 @@ def _train_sentencepiece(
107177 )
108178 SentencePieceTrainer .Train (argstr )
109179 if jax .process_index () == 0 :
110- # Use an intermediate filename that is renamed to the target name to address
111- # create and fill delays.
112- copy_rename_path = abs_model_path + ".rntmp"
113- tf .io .gfile .makedirs (os .path .dirname (abs_model_path ))
114- tf .io .gfile .copy (model_fp .name + ".model" , copy_rename_path , overwrite = True )
115- tf .io .gfile .rename (copy_rename_path , abs_model_path , overwrite = True )
116- logging .info ("copied %s to %s" , model_fp .name + ".model" , abs_model_path )
180+ if abs_model_path .startswith ("gs://" ):
181+ gcs_utils .upload_blob (abs_model_path , model_fp .name + ".model" )
182+ logging .info ("Uploaded %s to %s" , model_fp .name + ".model" , abs_model_path )
183+ else :
184+ parent = os .path .dirname (abs_model_path )
185+ if parent :
186+ os .makedirs (parent , exist_ok = True )
187+ shutil .copy (model_fp .name + ".model" , abs_model_path )
188+ logging .info ("Copied %s to %s" , model_fp .name + ".model" , abs_model_path )
117189 else :
118- while not tf .io .gfile .exists (abs_model_path ):
119- time .sleep (1 )
190+ if abs_model_path .startswith ("gs://" ):
191+ while not gcs_utils .gcs_path_exists (abs_model_path ):
192+ time .sleep (1 )
193+ else :
194+ while not os .path .exists (abs_model_path ):
195+ time .sleep (1 )
120196 time .sleep (1 )
121197 return abs_model_path
122198
123199
124200def train_tokenizer (
125- dataset : tf . data . Dataset ,
201+ dataset_iter : Iterator ,
126202 * ,
127203 vocab_path : str ,
128204 vocab_size : int ,
129205 max_corpus_chars : int ,
130206 data_keys : tuple [str ] = ("text" ,),
131207):
132- """tokenizer training function"""
208+ """Tokenizer training function. """
133209 logging .info ("SentencePiece vocab not found, building one from data." )
134210 vocab_path = _train_sentencepiece (
135- dataset ,
211+ dataset_iter ,
136212 vocab_size = vocab_size ,
137213 maxchars = max_corpus_chars ,
138214 model_path = vocab_path ,
@@ -143,19 +219,14 @@ def train_tokenizer(
143219
def main(argv):
  """Entry point: build a grain iterator from flags and train the tokenizer."""
  del argv  # unused; configuration comes from the module-level flags
  selected_keys = (_DATA_COLUMN.value,)
  vocab_output_path = os.path.join(_ASSETS_PATH.value, _VOCAB_MODEL_NAME.value)
  example_iter = build_grain_iterator(
      _GRAIN_TRAIN_FILES.value,
      _GRAIN_FILE_TYPE.value,
      data_keys=selected_keys,
  )
  train_tokenizer(
      example_iter,
      vocab_path=vocab_output_path,
      vocab_size=_VOCAB_SIZE.value,
      max_corpus_chars=_MAX_CORPUS_CHARS.value,
      data_keys=selected_keys,
  )
160231
161232
0 commit comments