Commit f4b9042

support grain checkpoint (#133)
1 parent 269b621 commit f4b9042

5 files changed

Lines changed: 46 additions & 8 deletions

docs/data_README.md

Lines changed: 8 additions & 1 deletion
@@ -6,7 +6,8 @@ Currently MaxDiffusion supports 3 data input pipelines, controlled by the flag `
 | -------- | ---------------- | --------------- | ----------------------- |
 | HuggingFace (hf) | datasets in HuggingFace Hub or local/Cloud Storage | Formats supported in HF Hub: parquet, arrow, json, csv, txt | data are not loaded in memory but streamed from the saved location, good for big datasets |
 | tf | dataset will be downloaded from HuggingFace Hub to disk | Formats supported in HF Hub: parquet, arrow, json, csv, txt | Will read the whole dataset into memory, works for small datasets |
-| tfrecord | local/Cloud Storage | tfrecord | data are not loaded in memory but streamed from the saved location, good for big datasets |
+| tfrecord | local/Cloud Storage | TFRecord | data are not loaded in memory but streamed from the saved location, good for big datasets |
+| Grain | local/Cloud Storage | ArrayRecord (or any random access format) | data are not loaded in memory but streamed from the saved location, good for big datasets, supports global shuffle and data iterator checkpointing for determinism (see details in [doc](https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-pipeline---for-determinism)) |
 
 ## Usage examples
 
@@ -45,6 +46,12 @@ dataset_type: tfrecord
 train_data_dir: gs://<bucket>/<folder> # will use all TFRecord files under the directory
 ```
 
+### Grain (dataset_type=grain)
+```
+dataset_type: grain
+grain_train_files: gs://<bucket>/<folder>/*.arrayrecord # match the file pattern
+```
+
 ## Best Practice
 ### Multihost Dataloading
 In a multihost environment, if using a streaming type of input pipeline whose data format only supports sequential reads (dataset_type in (hf, tfrecord) in MaxDiffusion), the most performant approach is to have each data file accessed by only one host, with each host reading a subset of the data files (shuffling happens within that subset of files). This requires (# of data files) > (# of hosts loading data). We recommend users reshard the dataset if this requirement is not met.
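The file-per-host requirement above can be satisfied by statically partitioning the file list across processes. The snippet below is a minimal sketch of that idea, not MaxDiffusion's actual sharding logic; `shard_files_per_host` and the file pattern are hypothetical, while `jax.process_index()`/`jax.process_count()` and `tf.io.gfile.glob` are standard APIs.

```python
# Minimal sketch: give each host a disjoint subset of data files so every file
# is read by exactly one host. Illustrative only, not MaxDiffusion's code.
import jax
import tensorflow as tf

def shard_files_per_host(pattern: str) -> list:
  files = sorted(tf.io.gfile.glob(pattern))  # tf.io.gfile.glob understands gs:// paths
  num_hosts = jax.process_count()
  host_id = jax.process_index()
  # Requirement from the doc above: at least as many files as data-loading hosts.
  assert len(files) >= num_hosts, "reshard the dataset so (# files) >= (# hosts)"
  # Each host keeps every num_hosts-th file; shuffling then happens within this subset.
  return files[host_id::num_hosts]
```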

src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py

Lines changed: 25 additions & 3 deletions
@@ -22,6 +22,7 @@
 import jax
 from jax.sharding import Mesh
 import orbax.checkpoint as ocp
+import grain.python as grain
 from maxdiffusion import (
     max_utils,
     FlaxStableDiffusionPipeline,
@@ -57,7 +58,11 @@ def __init__(self, config, checkpoint_type):
     self.total_train_batch_size = self.config.total_train_batch_size
 
     self.checkpoint_manager = create_orbax_checkpoint_manager(
-        self.config.checkpoint_dir, enable_checkpointing=True, save_interval_steps=1, checkpoint_type=checkpoint_type
+        self.config.checkpoint_dir,
+        enable_checkpointing=True,
+        save_interval_steps=1,
+        checkpoint_type=checkpoint_type,
+        dataset_type=config.dataset_type,
     )
 
   def _create_optimizer(self, config, learning_rate):
@@ -157,6 +162,22 @@ def create_text_encoder_2_state(self, pipeline, params, checkpoint_item_name, is
         training=is_training,
     )
 
+  def restore_data_iterator_state(self, data_iterator):
+    if (
+        self.config.dataset_type == "grain"
+        and data_iterator is not None
+        and (self.checkpoint_manager.directory / str(self.checkpoint_manager.latest_step()) / "iter").exists()
+    ):
+      max_logging.log("Restoring data iterator from checkpoint")
+      restored = self.checkpoint_manager.restore(
+          self.checkpoint_manager.latest_step(),
+          args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)),
+      )
+      data_iterator.local_iterator = restored["iter"]
+    else:
+      max_logging.log("data iterator checkpoint not found")
+    return data_iterator
+
   def _get_pipeline_class(self):
     if self.checkpoint_type == STABLE_DIFFUSION_CHECKPOINT:
       pipeline_class = FlaxStableDiffusionPipeline
@@ -212,7 +233,7 @@ def load_diffusers_checkpoint(self):
     params = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), params)
     return pipeline, params
 
-  def save_checkpoint(self, train_step, pipeline, params, train_states):
+  def save_checkpoint(self, train_step, pipeline, params, train_states, data_iterator=None):
     def config_to_json(model_or_config):
       return json.loads(model_or_config.to_json_string())
 
@@ -233,7 +254,8 @@ def config_to_json(model_or_config):
 
     tokenizer_config = {"path": self.config.tokenizer_model_name_or_path}
     items["tokenizer_config"] = ocp.args.JsonSave(tokenizer_config)
-
+    if self.config.dataset_type == "grain":
+      items["iter"] = grain.PyGrainCheckpointSave(data_iterator.local_iterator)
     self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
 
   def load_params(self, step=None):
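Taken together, the changes above form a save/restore round trip for the Grain iterator through Orbax. The sketch below shows that round trip at a hypothetical call site; `manager`, `items`, and `train_iter` are illustrative names, while `ocp.args.Composite`, `grain.PyGrainCheckpointSave`, and `grain.PyGrainCheckpointRestore` are the same APIs used in this diff.

```python
# Illustrative round trip, assuming `manager` is an Orbax CheckpointManager created
# with "iter" among its item names, and `train_iter` is the Grain iterator
# (e.g. data_iterator.local_iterator above).
import grain.python as grain
import orbax.checkpoint as ocp

def save_step(manager, step, items, train_iter):
  # Save the data iterator alongside the model items under the "iter" key.
  items["iter"] = grain.PyGrainCheckpointSave(train_iter)
  manager.save(step, args=ocp.args.Composite(**items))

def restore_latest_iterator(manager, train_iter):
  # Restore the iterator state from the most recent step; it resumes from the saved position.
  restored = manager.restore(
      manager.latest_step(),
      args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(train_iter)),
  )
  return restored["iter"]
```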

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,7 @@ def create_orbax_checkpoint_manager(
     enable_checkpointing: bool,
     save_interval_steps,
     checkpoint_type: str,
+    dataset_type: str = "tf",
     use_async: bool = True,
     orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
 ):
@@ -70,6 +71,8 @@
       "text_encoder_2_state",
       "text_encoder_2_config",
   )
+  if dataset_type == "grain":
+    item_names += ("iter",)
 
   print("item_names: ", item_names)
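With the new `dataset_type` parameter, a checkpoint manager built for a Grain pipeline also tracks the iterator item. A hypothetical call site might look like the following; the path and the `checkpoint_type` value are placeholders, and the import path is assumed from this file's location.

```python
# Hypothetical usage of the updated signature; values are placeholders.
from maxdiffusion.checkpointing.checkpointing_utils import create_orbax_checkpoint_manager

manager = create_orbax_checkpoint_manager(
    "gs://<bucket>/checkpoints",         # checkpoint_dir (placeholder)
    enable_checkpointing=True,
    save_interval_steps=1,
    checkpoint_type="stable_diffusion",  # placeholder; the trainer passes its own checkpoint_type
    dataset_type="grain",                # adds "iter" to the managed item names
)
```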

src/maxdiffusion/trainers/base_stable_diffusion_trainer.py

Lines changed: 2 additions & 0 deletions
@@ -128,6 +128,8 @@ def start_training(self):
 
     # Load dataset
     data_iterator = self.load_dataset(pipeline, params, train_states)
+    if self.config.dataset_type == "grain":
+      data_iterator = self.restore_data_iterator_state(data_iterator)
 
     data_shardings = self.get_data_shardings()
     # Compile train_step

src/maxdiffusion/trainers/stable_diffusion_trainer.py

Lines changed: 8 additions & 4 deletions
@@ -15,6 +15,7 @@
 """
 
 import os
+import sys
 from functools import partial
 import datetime
 import time
@@ -211,7 +212,6 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
       unet_state, text_encoder_state, train_metric, train_rngs = p_train_step(
           unet_state, vae_state, text_encoder_state, example_batch, train_rngs
       )
-      samples_count = self.total_train_batch_size * (step + 1)
       new_time = datetime.datetime.now()
 
       train_utils.record_scalar_metrics(
@@ -221,11 +221,15 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
       train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
       last_step_completion = new_time
 
-      if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0:
+      if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
         train_states["unet_state"] = unet_state
         train_states["vae_state"] = vae_state
         train_states["text_encoder"] = text_encoder_state
-        self.save_checkpoint(step, pipeline, params, train_states)
+        self.save_checkpoint(step, pipeline, params, train_states, data_iterator)
+
+      if self.checkpoint_manager.reached_preemption(step):
+        self.checkpoint_manager.wait_until_finished()
+        sys.exit()
 
       if self.config.enable_profiler and step == last_profiling_step:
         max_utils.deactivate_profiler(self.config)
@@ -239,7 +243,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
     train_states["vae_state"] = vae_state
     train_states["text_encoder"] = text_encoder_state
     # save the inference states of the last checkpoint so they can be easily loaded during gen.
-    self.save_checkpoint(self.config.max_train_steps - 1, pipeline, params, train_states)
+    self.save_checkpoint(self.config.max_train_steps - 1, pipeline, params, train_states, data_iterator)
     self.checkpoint_manager.wait_until_finished()
 