Skip to content

Commit d05161d

Browse files
jfacevedo-google and ksikiric
authored and committed
refactor some code for similarity to sd trainers.
1 parent ee7d422 commit d05161d

2 files changed

Lines changed: 134 additions & 70 deletions

File tree

src/maxdiffusion/checkpointing/flux_checkpointer.py

Lines changed: 93 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -16,12 +16,10 @@
1616

1717
from abc import ABC
1818
from contextlib import nullcontext
19-
import os
20-
import json
2119
import functools
2220
import jax
2321
import jax.numpy as jnp
24-
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
22+
from jax.sharding import Mesh
2523
import orbax.checkpoint as ocp
2624
import grain.python as grain
2725
from maxdiffusion import (
@@ -35,15 +33,19 @@
3533
from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)
3634

3735
from maxdiffusion.checkpointing.checkpointing_utils import (
38-
create_orbax_checkpoint_manager,
39-
load_stable_diffusion_configs,
36+
create_orbax_checkpoint_manager
4037
)
4138
from maxdiffusion.models.flux.util import load_flow_model
4239

4340
FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
44-
_CHECKPOINT_FORMAT_DIFFUSERS = "CHECKPOINT_FORMAT_DIFFUSERS"
4541
_CHECKPOINT_FORMAT_ORBAX = "CHECKPOINT_FORMAT_ORBAX"
4642

43+
FLUX_STATE_KEY = "flux_state"
44+
FLUX_TRANSFORMER_PARAMS_KEY = "flux_transformer_params"
45+
FLUX_STATE_SHARDINGS_KEY = "flux_state_shardings"
46+
FLUX_VAE_PARAMS_KEY = "flux_vae"
47+
VAE_STATE_KEY = "vae_state"
48+
VAE_STATE_SHARDINGS_KEY = "vae_state_shardings"
4749

4850
class FluxCheckpointer(ABC):
4951

@@ -144,67 +146,106 @@ def _set_checkpoint_format(self, checkpoint_format):
144146
self.checkpoint_format = checkpoint_format
145147

146148
def save_checkpoint(self, train_step, pipeline, train_states):
149+
def config_to_json(model_or_config):
150+
return json.loads(model_or_config.to_json_string())
147151
items = {
148152
"config": ocp.args.JsonSave({"model_name": self.config.model_name}),
149153
}
150154

151-
items["flux_state"] = ocp.args.PyTreeSave(train_states["flux_state"])
155+
items[FLUX_STATE_KEY] = ocp.args.PyTreeSave(train_states[FLUX_STATE_KEY])
152156

153157
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
154158

155159
def load_params(self, step=None):
156160

157161
self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX
162+
163+
def load_flux_configs_from_orbax(self):
164+
# TODO - load configs from orbax
165+
return None
158166

159-
def load_checkpoint(self, step=None, scheduler_class=None):
160-
clip_encoder = FlaxCLIPTextModel.from_pretrained(
161-
self.config.clip_model_name_or_path, dtype=self.config.weights_dtype
162-
)
163-
clip_tokenizer = CLIPTokenizer.from_pretrained(
164-
self.config.clip_model_name_or_path, max_length=77, use_fast=True
165-
)
167+
def load_diffusers_checkpoint(self):
168+
flash_block_sizes = max_utils.get_flash_block_sizes(self.config)
166169

167-
t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
168-
t5_tokenizer = AutoTokenizer.from_pretrained(
169-
self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
170+
if jax.device_count() == jax.local_device_count():
171+
context = jax.default_device(jax.devices("cpu")[0])
172+
else:
173+
context = nullcontext()
174+
175+
with context:
176+
clip_encoder = FlaxCLIPTextModel.from_pretrained(
177+
self.config.clip_model_name_or_path, dtype=self.config.weights_dtype
178+
)
179+
clip_tokenizer = CLIPTokenizer.from_pretrained(
180+
self.config.clip_model_name_or_path,
181+
max_length=77,
182+
use_fast=True
183+
)
184+
t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
185+
t5_tokenizer = AutoTokenizer.from_pretrained(
186+
self.config.t5xxl_model_name_or_path,
187+
max_length=self.config.max_sequence_length,
188+
use_fast=True
189+
)
190+
191+
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
192+
self.config.pretrained_model_name_or_path,
193+
subfolder="vae",
194+
from_pt=True,
195+
use_safetensors=True,
196+
dtype=self.config.weights_dtype
197+
)
198+
199+
# loading from pretrained here causes a crash when trying to compile the model
200+
# Failed to load HSACO: HIP_ERROR_NoBinaryForGpu
201+
transformer = FluxTransformer2DModel.from_config(
202+
self.config.pretrained_model_name_or_path,
203+
subfolder="transformer",
204+
mesh=self.mesh,
205+
split_head_dim=self.config.split_head_dim,
206+
attention_kernel=self.config.attention,
207+
flash_block_sizes=flash_block_sizes,
208+
dtype=self.config.activations_dtype,
209+
weights_dtype=self.config.weights_dtype,
210+
precision=max_utils.get_precision(self.config),
211+
)
212+
transformer_eval_params = transformer.init_weights(
213+
rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
214+
)
215+
216+
transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")
217+
218+
pipeline = FluxPipeline(
219+
t5_encoder,
220+
clip_encoder,
221+
vae,
222+
t5_tokenizer,
223+
clip_tokenizer,
224+
transformer,
225+
None,
226+
dtype=self.config.activations_dtype,
227+
mesh=self.mesh,
228+
config=self.config,
229+
rng=self.rng
170230
)
171-
encoders_sharding = PositionalSharding(self.devices_array).replicate()
172-
partial_device_put_replicated = functools.partial(max_utils.device_put_replicated, sharding=encoders_sharding)
173-
clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_encoder.params)
174-
clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_encoder.params)
175-
t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params)
176-
t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)
177231

232+
params = {
233+
FLUX_VAE_PARAMS_KEY : vae_params,
234+
FLUX_TRANSFORMER_PARAMS_KEY : transformer_params
235+
}
178236

237+
return pipeline, params
179238

180-
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
181-
self.config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16"
182-
)
239+
def load_checkpoint(self, step=None, scheduler_class=None):
183240

184-
flash_block_sizes = max_utils.get_flash_block_sizes(self.config)
185-
# loading from pretrained here causes a crash when trying to compile the model
186-
# Failed to load HSACO: HIP_ERROR_NoBinaryForGpu
187-
transformer = FluxTransformer2DModel.from_config(
188-
self.config.pretrained_model_name_or_path,
189-
subfolder="transformer",
190-
mesh=self.mesh,
191-
split_head_dim=self.config.split_head_dim,
192-
attention_kernel=self.config.attention,
193-
flash_block_sizes=flash_block_sizes,
194-
dtype=self.config.activations_dtype,
195-
weights_dtype=self.config.weights_dtype,
196-
precision=max_utils.get_precision(self.config),
197-
)
198-
199-
return FluxPipeline(t5_encoder,
200-
clip_encoder,
201-
vae,
202-
t5_tokenizer,
203-
clip_tokenizer,
204-
transformer,
205-
None,
206-
dtype=self.config.activations_dtype,
207-
mesh=self.mesh,
208-
config=self.config,
209-
rng=self.rng), vae_params
241+
model_configs = self.load_flux_configs_from_orbax()
242+
243+
pipeline, params = None, {}
244+
245+
if model_configs:
246+
print("TODO - load configs from orbax")
247+
else:
248+
pipeline, params = self.load_diffusers_checkpoint()
249+
250+
return pipeline, params
210251

src/maxdiffusion/trainers/flux_trainer.py

Lines changed: 41 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,17 @@
2222
import jax
2323
import optax
2424
import jax.numpy as jnp
25-
from jax.sharding import PartitionSpec as P
25+
from jax.sharding import PositionalSharding, PartitionSpec as P
2626
from flax.linen import partitioning as nn_partitioning
27-
from maxdiffusion.checkpointing.flux_checkpointer import (FluxCheckpointer, FLUX_CHECKPOINT)
27+
from maxdiffusion.checkpointing.flux_checkpointer import (
28+
FluxCheckpointer,
29+
FLUX_CHECKPOINT,
30+
FLUX_TRANSFORMER_PARAMS_KEY,
31+
FLUX_STATE_KEY,
32+
FLUX_STATE_SHARDINGS_KEY,
33+
FLUX_VAE_PARAMS_KEY,
34+
VAE_STATE_KEY,
35+
VAE_STATE_SHARDINGS_KEY)
2836

2937
from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator)
3038

@@ -57,7 +65,7 @@ def __init__(self, config):
5765
raise ValueError("this script currently doesn't support training text_encoders")
5866

5967
def post_training_steps(self, pipeline, params, train_states, msg=""):
60-
imgs = pipeline(flux_params=train_states["flux_state"],
68+
imgs = pipeline(flux_params=train_states[FLUX_STATE_KEY],
6169
timesteps=50,
6270
vae_params=train_states["vae_state"])
6371
imgs = np.array(imgs)
@@ -94,11 +102,21 @@ def start_training(self):
94102
# create train states
95103
train_states = {}
96104
state_shardings = {}
105+
106+
# move params to accelerator
107+
encoders_sharding = PositionalSharding(self.devices_array).replicate()
108+
partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
109+
pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
110+
pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)
111+
pipeline.t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.t5_encoder.params)
112+
pipeline.t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.t5_encoder.params)
113+
114+
97115
vae_state, vae_state_mesh_shardings = self.create_vae_state(
98-
pipeline=pipeline, params=params, checkpoint_item_name="vae_state", is_training=False
116+
pipeline=pipeline, params=params[FLUX_VAE_PARAMS_KEY], checkpoint_item_name=VAE_STATE_KEY, is_training=False
99117
)
100-
train_states["vae_state"] = vae_state
101-
state_shardings["vae_state_shardings"] = vae_state_mesh_shardings
118+
train_states[VAE_STATE_KEY] = vae_state
119+
state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings
102120

103121
# Load dataset
104122
data_iterator = self.load_dataset(pipeline, params, train_states)
@@ -107,18 +125,23 @@ def start_training(self):
107125

108126
# don't need this anymore, clear some memory.
109127
del pipeline.t5_encoder
128+
129+
# evaluate shapes
130+
110131
flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state(
111-
# ambiguous here, but if self.params.get("unet") doesn't exist
132+
# ambiguous here, but if params=None
112133
# Then its 1 of 2 scenarios:
113134
# 1. unet state will be loaded directly from orbax
114135
# 2. a new unet is being trained from scratch.
115136
pipeline=pipeline,
116137
params=None, # Params are loaded inside create_flux_state
117-
checkpoint_item_name="flux_state",
138+
checkpoint_item_name=FLUX_STATE_KEY,
118139
is_training=True,
119140
)
120-
train_states["flux_state"] = flux_state
121-
state_shardings["flux_state_shardings"] = flux_state_mesh_shardings
141+
flux_state = flux_state.replace(params=params[FLUX_TRANSFORMER_PARAMS_KEY])
142+
flux_state = jax.device_put(flux_state, flux_state_mesh_shardings)
143+
train_states[FLUX_STATE_KEY] = flux_state
144+
state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings
122145
#self.post_training_steps(pipeline, params, train_states, msg="before_training")
123146

124147
# Create scheduler
@@ -320,15 +343,15 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da
320343
max_logging.log("Precompiling...")
321344
s = time.time()
322345
dummy_batch = self.get_shaped_batch(self.config, pipeline)
323-
p_train_step = p_train_step.lower(train_states["flux_state"], dummy_batch, train_rngs)
346+
p_train_step = p_train_step.lower(train_states[FLUX_STATE_KEY], dummy_batch, train_rngs)
324347
p_train_step = p_train_step.compile()
325348
max_logging.log(f"Compile time: {(time.time() - s )}")
326349
return p_train_step
327350

328351
def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler):
329352

330353
writer = max_utils.initialize_summary_writer(self.config)
331-
flux_state = train_states["flux_state"]
354+
flux_state = train_states[FLUX_STATE_KEY]
332355
num_model_parameters = max_utils.calculate_num_params_from_pytree(flux_state.params)
333356

334357
max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer)
@@ -352,7 +375,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
352375
last_profiling_step = np.clip(
353376
first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1
354377
)
355-
start_step = get_first_step(train_states["flux_state"])
378+
start_step = get_first_step(train_states[FLUX_STATE_KEY])
356379
_, train_rngs = jax.random.split(self.rng)
357380
times = []
358381
for step in np.arange(start_step, self.config.max_train_steps):
@@ -379,7 +402,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
379402

380403
if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0:
381404
max_logging.log(f"Saving checkpoint for step {step}")
382-
train_states["flux_state"] = flux_state
405+
train_states[FLUX_STATE_KEY] = flux_state
383406
self.save_checkpoint(step, pipeline, train_states)
384407

385408
if self.config.enable_profiler and step == last_profiling_step:
@@ -390,7 +413,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
390413
writer, local_metrics_file, running_gcs_metrics, train_metric, self.config.max_train_steps - 1, self.config
391414
)
392415

393-
train_states["flux_state"] = flux_state
416+
train_states[FLUX_STATE_KEY] = flux_state
394417
max_logging.log(f"Average time per step: {sum(times[2:], datetime.timedelta(0)) / len(times[2:])}")
395418
if self.config.save_final_checkpoint:
396419
max_logging.log(f"Saving checkpoint for step {step}")
@@ -402,7 +425,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
402425
def _train_step(flux_state, batch, train_rng, guidance_vec, pipeline, scheduler, config):
403426
_, gen_dummy_rng = jax.random.split(train_rng)
404427
sample_rng, timestep_bias_rng, new_train_rng = jax.random.split(gen_dummy_rng, 3)
405-
state_params = {"flux_state": flux_state.params}
428+
state_params = {FLUX_STATE_KEY: flux_state.params}
406429

407430
def compute_loss(state_params):
408431
latents = batch["pixel_values"]
@@ -424,7 +447,7 @@ def compute_loss(state_params):
424447
noisy_latents = pipeline.scheduler.add_noise(scheduler, latents, noise, timesteps, flux=True)
425448

426449
model_pred = pipeline.flux.apply(
427-
{"params": state_params["flux_state"]},
450+
{"params": state_params[FLUX_STATE_KEY]},
428451
hidden_states=noisy_latents,
429452
img_ids=img_ids,
430453
encoder_hidden_states=text_embeds,
@@ -444,7 +467,7 @@ def compute_loss(state_params):
444467
grad_fn = jax.value_and_grad(compute_loss)
445468
loss, grad = grad_fn(state_params)
446469

447-
new_state = flux_state.apply_gradients(grads=grad["flux_state"])
470+
new_state = flux_state.apply_gradients(grads=grad[FLUX_STATE_KEY])
448471

449472
metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
450473

0 commit comments

Comments (0)