|
17 | 17 | from abc import ABC |
18 | 18 | from contextlib import nullcontext |
19 | 19 | import functools |
| 20 | +import json |
| 21 | +import os |
20 | 22 | import jax |
21 | 23 | import jax.numpy as jnp |
22 | 24 | from jax.sharding import Mesh |
@@ -59,8 +61,10 @@ def __init__(self, config, checkpoint_type): |
59 | 61 | self.mesh = Mesh(self.devices_array, self.config.mesh_axes) |
60 | 62 | self.total_train_batch_size = self.config.total_train_batch_size |
61 | 63 |
|
| 64 | + checkpoint_dir = os.path.abspath(self.config.checkpoint_dir) |
| 65 | + |
62 | 66 | self.checkpoint_manager = create_orbax_checkpoint_manager( |
63 | | - self.config.checkpoint_dir, |
| 67 | + checkpoint_dir, |
64 | 68 | enable_checkpointing=True, |
65 | 69 | save_interval_steps=1, |
66 | 70 | checkpoint_type=checkpoint_type, |
@@ -117,7 +121,7 @@ def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=F |
117 | 121 | config=self.config, |
118 | 122 | mesh=self.mesh, |
119 | 123 | weights_init_fn=weights_init_fn, |
120 | | - model_params=params, |
| 124 | + model_params=params.get("flux_vae", None), |
121 | 125 | checkpoint_manager=self.checkpoint_manager, |
122 | 126 | checkpoint_item=checkpoint_item_name, |
123 | 127 | training=is_training, |
@@ -149,20 +153,35 @@ def save_checkpoint(self, train_step, pipeline, train_states): |
149 | 153 | def config_to_json(model_or_config): |
150 | 154 | return json.loads(model_or_config.to_json_string()) |
151 | 155 | items = { |
152 | | - "config": ocp.args.JsonSave({"model_name": self.config.model_name}), |
| 156 | + "flux_config": ocp.args.JsonSave(config_to_json(pipeline.flux)), |
| 157 | + "vae_config": ocp.args.JsonSave(config_to_json(pipeline.vae)), |
| 158 | + "scheduler_config": ocp.args.JsonSave(config_to_json(pipeline.scheduler)) |
153 | 159 | } |
154 | 160 |
|
155 | 161 | items[FLUX_STATE_KEY] = ocp.args.PyTreeSave(train_states[FLUX_STATE_KEY]) |
| 162 | + items["vae_state"] = ocp.args.PyTreeSave(train_states["vae_state"]) |
| 163 | + items["scheduler"] = ocp.args.PyTreeSave(train_states["scheduler"]) |
156 | 164 |
|
157 | 165 | self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) |
158 | 166 |
|
  def load_params(self, step=None):
    """Record that parameters are loaded via the Orbax checkpoint format.

    Sets `self.checkpoint_format` to `_CHECKPOINT_FORMAT_ORBAX`; no weights
    are actually read here.

    Args:
      step: unused in the visible implementation — NOTE(review): presumably
        kept for interface compatibility with other checkpointer classes;
        confirm against siblings/callers.
    """
    self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX
162 | 170 |
|
163 | | - def load_flux_configs_from_orbax(self): |
164 | | - # TODO - load configs from orbax |
165 | | - return None |
| 171 | + def load_flux_configs_from_orbax(self, step): |
| 172 | + max_logging.log("Restoring stable diffusion configs") |
| 173 | + if step is None: |
| 174 | + step = self.checkpoint_manager.latest_step() |
| 175 | + if step is None: |
| 176 | + return None |
| 177 | + |
| 178 | + restore_args = { |
| 179 | + "flux_config": ocp.args.JsonRestore(), |
| 180 | + "vae_config": ocp.args.JsonRestore(), |
| 181 | + "scheduler_config": ocp.args.JsonRestore(), |
| 182 | + } |
| 183 | + |
| 184 | + return (self.checkpoint_manager.restore(step, args=ocp.args.Composite(**restore_args)), None) |
166 | 185 |
|
167 | 186 | def load_diffusers_checkpoint(self): |
168 | 187 | flash_block_sizes = max_utils.get_flash_block_sizes(self.config) |
@@ -238,12 +257,65 @@ def load_diffusers_checkpoint(self): |
238 | 257 |
|
239 | 258 | def load_checkpoint(self, step=None, scheduler_class=None): |
240 | 259 |
|
241 | | - model_configs = self.load_flux_configs_from_orbax() |
| 260 | + model_configs = self.load_flux_configs_from_orbax(step) |
242 | 261 |
|
243 | 262 | pipeline, params = None, {} |
244 | 263 |
|
245 | 264 | if model_configs: |
246 | | - print("TODO - load configs from orbax") |
| 265 | + if jax.device_count() == jax.local_device_count(): |
| 266 | + context = jax.default_device(jax.devices("cpu")[0]) |
| 267 | + else: |
| 268 | + context = nullcontext() |
| 269 | + |
| 270 | + with context: |
| 271 | + clip_encoder = FlaxCLIPTextModel.from_pretrained( |
| 272 | + self.config.clip_model_name_or_path, dtype=self.config.weights_dtype |
| 273 | + ) |
| 274 | + clip_tokenizer = CLIPTokenizer.from_pretrained( |
| 275 | + self.config.clip_model_name_or_path, |
| 276 | + max_length=77, |
| 277 | + use_fast=True |
| 278 | + ) |
| 279 | + t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype) |
| 280 | + t5_tokenizer = AutoTokenizer.from_pretrained( |
| 281 | + self.config.t5xxl_model_name_or_path, |
| 282 | + max_length=self.config.max_sequence_length, |
| 283 | + use_fast=True |
| 284 | + ) |
| 285 | + |
| 286 | + vae = FlaxAutoencoderKL.from_config( |
| 287 | + model_configs[0]["vae_config"], |
| 288 | + dtype=self.config.activations_dtype, |
| 289 | + weights_dtype=self.config.weights_dtype, |
| 290 | + from_pt=self.config.from_pt, |
| 291 | + ) |
| 292 | + |
| 293 | + transformer = FluxTransformer2DModel.from_config( |
| 294 | + model_configs[0]["flux_config"], |
| 295 | + mesh=self.mesh, |
| 296 | + split_head_dim=self.config.split_head_dim, |
| 297 | + attention_kernel=self.config.attention, |
| 298 | + flash_block_sizes=max_utils.get_flash_block_sizes(self.config), |
| 299 | + dtype=self.config.activations_dtype, |
| 300 | + weights_dtype=self.config.weights_dtype, |
| 301 | + precision=max_utils.get_precision(self.config), |
| 302 | + from_pt=self.config.from_pt, |
| 303 | + ) |
| 304 | + |
| 305 | + pipeline = FluxPipeline( |
| 306 | + t5_encoder, |
| 307 | + clip_encoder, |
| 308 | + vae, |
| 309 | + t5_tokenizer, |
| 310 | + clip_tokenizer, |
| 311 | + transformer, |
| 312 | + None, |
| 313 | + dtype=self.config.activations_dtype, |
| 314 | + mesh=self.mesh, |
| 315 | + config=self.config, |
| 316 | + rng=self.rng |
| 317 | + ) |
| 318 | + |
247 | 319 | else: |
248 | 320 | pipeline, params = self.load_diffusers_checkpoint() |
249 | 321 |
|
|
0 commit comments