1313# limitations under the License.
1414
1515""" Train tokenizer
16- Example usage: python3 -m MaxText.train_tokenizer --dataset_path=gs://maxtext-dataset --dataset_name=c4/en:3.0.1
16+ Example usage (parquet):
17+ python3 -m MaxText.train_tokenizer \
18+ --grain_train_files=gs://my-bucket/data/*.parquet \
19+ --grain_file_type=parquet
20+
21+ Example usage (arrayrecord):
22+ python3 -m MaxText.train_tokenizer \
23+ --grain_train_files=gs://my-bucket/data/*.arrayrecord \
24+ --grain_file_type=arrayrecord \
25+ --data_column=text
1726"""
1827
28+ import glob
1929import os
20- import sys
30+ import shutil
2131import tempfile
2232import time
33+ from collections .abc import Iterator
34+ from pathlib import Path
2335
2436from absl import app
2537from absl import flags
2840from sentencepiece import SentencePieceTrainer
2941
3042import jax
31-
32- import tensorflow as tf
33- import tensorflow_datasets as tfds
43+ import grain .python as grain
3444
3545from maxtext .utils .globals import MAXTEXT_ASSETS_ROOT
46+ from maxtext .utils import gcs_utils
47+
3648
37- _DATASET_PATH = flags .DEFINE_string ("dataset_path" , None , "Path to the dataset" , required = True )
38- _DATASET_NAME = flags .DEFINE_string ("dataset_name" , None , "Name to the dataset" , required = True )
_GRAIN_TRAIN_FILES = flags.DEFINE_string(
    "grain_train_files", None, "File pattern for training data (local or gs://)", required=True
)
# DEFINE_enum validates the value at flag-parsing time, failing fast on typos
# instead of surfacing a ValueError deep inside the data pipeline.
_GRAIN_FILE_TYPE = flags.DEFINE_enum(
    "grain_file_type", "parquet", ["parquet", "arrayrecord"], "Type of data files. Supported: 'parquet', 'arrayrecord'."
)
_DATA_COLUMN = flags.DEFINE_string("data_column", "text", "Column name to extract text from (used for arrayrecord).")
_VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size")
_MAX_CORPUS_CHARS = flags.DEFINE_integer("max_corpus_chars", 10_000_000, "Max corpus chars")
_ASSETS_PATH = flags.DEFINE_string("assets_path", MAXTEXT_ASSETS_ROOT, "Path to assets directory")
_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Output tokenizer model name")
60+
61+
def build_grain_iterator(data_file_pattern: str, data_file_type: str, data_keys: tuple[str, ...] = ("text",)) -> Iterator:
  """Build a grain iterator from a file pattern for tokenizer training.

  Args:
    data_file_pattern: Glob pattern for data files (local path or gs://).
    data_file_type: One of 'arrayrecord' or 'parquet'.
    data_keys: Column names to extract from each example (used for arrayrecord;
      the parquet path yields full row dicts and ignores this argument).

  Returns:
    A Python iterator yielding examples as dicts.

  Raises:
    FileNotFoundError: If no files match `data_file_pattern`.
    ValueError: If `data_file_type` is not a supported type.
  """
  if data_file_pattern.startswith("gs://"):
    data_files = gcs_utils.gcs_glob_pattern(data_file_pattern)
  else:
    data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve()))
  if not data_files:
    raise FileNotFoundError(f"No files found matching pattern: {data_file_pattern}")
  # Sort for a deterministic file order: glob/GCS listing order is not guaranteed,
  # and the tokenizer only consumes a prefix of the corpus, so ordering matters
  # for reproducible vocab training.
  data_files = sorted(data_files)
  logging.info("Found %d files for tokenizer training.", len(data_files))

  if data_file_type == "parquet":
    dataset = grain.MapDataset.source(data_files)
    # Each source element (a file path) becomes its own ParquetIterDataset,
    # then all per-file datasets are interleaved round-robin.
    dataset = dataset.map(grain.experimental.ParquetIterDataset)
    dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=len(data_files))
    return iter(dataset)
  elif data_file_type == "arrayrecord":
    from maxtext.input_pipeline.protos import example_pb2  # pylint: disable=import-outside-toplevel

    source = grain.ArrayRecordDataSource(data_files)
    dataset = grain.MapDataset.source(source)

    def _parse_example(raw_bytes):
      # Decode the serialized Example proto and keep only the requested columns.
      # NOTE(review): only the first bytes_list value per feature is taken —
      # assumes one value per column; confirm against the data-writing side.
      example = example_pb2.Example()
      example.ParseFromString(raw_bytes)
      features = example.features.feature
      parsed = {}
      for col in data_keys:
        if col in features:
          parsed[col] = features[col].bytes_list.value[0]
      return parsed

    dataset = dataset.map(_parse_example)
    return iter(dataset)
  else:
    raise ValueError(f"Unsupported grain_file_type: {data_file_type!r}. Use 'parquet' or 'arrayrecord'.")
106+
43107
108+ def _dump_chars_to_textfile (dataset_iter : Iterator , maxchars : int = int (1e7 ), data_keys = ("text" ,)) -> tuple [str , int ]:
109+ """Write part of a grain dataset to lines in a text file.
44110
45- def _dump_chars_to_textfile (dataset : tf .data .Dataset , maxchars : int = int (1e7 ), data_keys = ("text" ,)) -> tuple [str , int ]:
46- """Write part of a TFDS sentence dataset to lines in a text file.
47111 Args:
48- dataset: tf.dataset containing string-data.
49- maxchars: int: approximate number of characters to save from dataset.
50- data_keys: tuple[str]: what keys in dataset to dump from.
112+ dataset_iter: Iterator yielding examples as dicts.
113+ maxchars: Approximate number of characters to save from dataset.
114+ data_keys: Keys in each example to dump.
115+
51116 Returns:
52- name of temp file with dataset bytes, exact number of characters dumped.
117+ Name of temp file with dataset bytes, exact number of characters dumped.
53118 """
54119 char_count = 0
55- ds_iter = dataset .as_numpy_iterator ()
56120 temp_dir = tempfile .gettempdir ()
57- with tempfile .NamedTemporaryFile (delete = False , prefix = os .path .join (temp_dir , "ds_chars" )) as outfp :
121+ with tempfile .NamedTemporaryFile (
122+ delete = False , prefix = os .path .join (temp_dir , "ds_chars" ), mode = "w" , encoding = "utf-8"
123+ ) as outfp :
58124 while char_count < maxchars :
59- example = next (ds_iter )
125+ example = next (dataset_iter )
60126 for k in data_keys :
61- line = example [k ] + b"\n "
127+ val = example [k ]
128+ if isinstance (val , bytes ):
129+ val = val .decode ("utf-8" )
130+ line = val + "\n "
62131 char_count += len (line )
63132 outfp .write (line )
64133 return outfp .name , char_count
65134
66135
67136def _train_sentencepiece (
68- dataset : tf . data . Dataset ,
137+ dataset_iter : Iterator ,
69138 * ,
70139 vocab_size : int ,
71140 maxchars : int = int (1e7 ),
@@ -74,25 +143,25 @@ def _train_sentencepiece(
74143 character_coverage : float = 1.0 ,
75144 data_keys = ("text" ,),
76145):
77- """Train SentencePiece tokenizer from subset of tf dataset.
146+ """Train SentencePiece tokenizer from subset of a grain dataset.
147+
78148 Args:
79- dataset: tf.dataset
80- vocab_size: int: size of vocab tokens to train.
81- maxchars: int: number of characters to use for sentencepiece training.
82- model_path: str: path of model file to save vocab model to.
83- model_type: str: type of sentencepiece vocab to train.
84- character_coverage: amount of characters covered by the model, good defaults
85- are 0.9995 for languages with rich character set like Japanese or Chinese
86- and 1.0 for other languages with small character set.
87- data_keys: tuple[str]: keys of dataset to use for training.
149+ dataset_iter: Iterator yielding examples as dicts.
150+ vocab_size: Size of vocab tokens to train.
151+ maxchars: Number of characters to use for sentencepiece training.
152+ model_path: Path to save vocab model to (local or gs://).
153+ model_type: Type of sentencepiece vocab to train.
154+ character_coverage: Amount of characters covered by the model.
155+ data_keys: Keys of dataset to use for training.
156+
88157 Returns:
89- path to the trained sentencepiece vocabulary model.
158+ Path to the trained sentencepiece vocabulary model.
90159 """
91160 if model_path .startswith ("gs://" ):
92161 abs_model_path = model_path
93162 else :
94163 abs_model_path = os .path .abspath (os .path .expanduser (model_path ))
95- fname , _ = _dump_chars_to_textfile (dataset , maxchars = maxchars , data_keys = data_keys )
164+ fname , _ = _dump_chars_to_textfile (dataset_iter , maxchars = maxchars , data_keys = data_keys )
96165 temp_dir = tempfile .gettempdir ()
97166 with tempfile .NamedTemporaryFile (delete = False , prefix = os .path .join (temp_dir , "sp_tmp" )) as model_fp :
98167 pass # we just want a prefix'd tmp-filename
@@ -107,32 +176,38 @@ def _train_sentencepiece(
107176 )
108177 SentencePieceTrainer .Train (argstr )
109178 if jax .process_index () == 0 :
110- # Use an intermediate filename that is renamed to the target name to address
111- # create and fill delays.
112- copy_rename_path = abs_model_path + ".rntmp"
113- tf .io .gfile .makedirs (os .path .dirname (abs_model_path ))
114- tf .io .gfile .copy (model_fp .name + ".model" , copy_rename_path , overwrite = True )
115- tf .io .gfile .rename (copy_rename_path , abs_model_path , overwrite = True )
116- logging .info ("copied %s to %s" , model_fp .name + ".model" , abs_model_path )
179+ if abs_model_path .startswith ("gs://" ):
180+ gcs_utils .upload_blob (abs_model_path , model_fp .name + ".model" )
181+ logging .info ("Uploaded %s to %s" , model_fp .name + ".model" , abs_model_path )
182+ else :
183+ parent = os .path .dirname (abs_model_path )
184+ if parent :
185+ os .makedirs (parent , exist_ok = True )
186+ shutil .copy (model_fp .name + ".model" , abs_model_path )
187+ logging .info ("Copied %s to %s" , model_fp .name + ".model" , abs_model_path )
117188 else :
118- while not tf .io .gfile .exists (abs_model_path ):
119- time .sleep (1 )
189+ if abs_model_path .startswith ("gs://" ):
190+ while not gcs_utils .gcs_path_exists (abs_model_path ):
191+ time .sleep (1 )
192+ else :
193+ while not os .path .exists (abs_model_path ):
194+ time .sleep (1 )
120195 time .sleep (1 )
121196 return abs_model_path
122197
123198
124199def train_tokenizer (
125- dataset : tf . data . Dataset ,
200+ dataset_iter : Iterator ,
126201 * ,
127202 vocab_path : str ,
128203 vocab_size : int ,
129204 max_corpus_chars : int ,
130205 data_keys : tuple [str ] = ("text" ,),
131206):
132- """tokenizer training function"""
207+ """Tokenizer training function. """
133208 logging .info ("SentencePiece vocab not found, building one from data." )
134209 vocab_path = _train_sentencepiece (
135- dataset ,
210+ dataset_iter ,
136211 vocab_size = vocab_size ,
137212 maxchars = max_corpus_chars ,
138213 model_path = vocab_path ,
@@ -143,19 +218,14 @@ def train_tokenizer(
143218
def main(argv):
  """Entry point: build a grain iterator from flags and train the tokenizer."""
  del argv
  selected_columns = (_DATA_COLUMN.value,)
  example_iterator = build_grain_iterator(
      _GRAIN_TRAIN_FILES.value, _GRAIN_FILE_TYPE.value, data_keys=selected_columns
  )
  output_vocab_path = os.path.join(_ASSETS_PATH.value, _VOCAB_MODEL_NAME.value)
  train_tokenizer(
      example_iterator,
      vocab_path=output_vocab_path,
      vocab_size=_VOCAB_SIZE.value,
      max_corpus_chars=_MAX_CORPUS_CHARS.value,
      data_keys=selected_columns,
  )
160230
161231
0 commit comments