
Commit bfec2c8

Added orbax saving and a new file for inference that utilizes the pipeline.

1 parent: d05161d

5 files changed: 233 additions & 41 deletions

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 16 additions & 10 deletions
@@ -57,16 +57,21 @@ def create_orbax_checkpoint_manager(
   max_logging.log(f"checkpoint dir: {checkpoint_dir}")
   p = epath.Path(checkpoint_dir)
 
-  item_names = (
-      "unet_config",
-      "vae_config",
-      "text_encoder_config",
-      "scheduler_config",
-      "unet_state",
-      "vae_state",
-      "text_encoder_state",
-      "tokenizer_config",
-  )
+  if checkpoint_type == FLUX_CHECKPOINT:
+    item_names = ("flux_state", "flux_config",
+                  "vae_state", "vae_config",
+                  "scheduler", "scheduler_config")
+  else:
+    item_names = (
+        "unet_config",
+        "vae_config",
+        "text_encoder_config",
+        "scheduler_config",
+        "unet_state",
+        "vae_state",
+        "text_encoder_state",
+        "tokenizer_config",
+    )
   if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
     item_names += (
         "text_encoder_2_state",
@@ -140,6 +145,7 @@ def load_params_from_path(
 
   ckpt_path = os.path.join(config.checkpoint_dir, str(step), checkpoint_item)
   ckpt_path = epath.Path(ckpt_path)
+  ckpt_path = os.path.abspath(ckpt_path)
 
   restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params)
   restored = ckptr.restore(
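
A hedged sketch of the restore path this function drives: Orbax restores into an abstract pytree describing target shapes and dtypes, so only the requested item is materialized. The function below is an illustration built from the calls visible above, not the repo's exact code:

import orbax.checkpoint as ocp

def restore_item(ckpt_path, abstract_params):
  # abstract_params: pytree of jax.ShapeDtypeStruct leaves (as produced by
  # max_utils.get_abstract_state); construct_restore_args turns it into
  # per-leaf restore instructions.
  restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_params)
  ckptr = ocp.PyTreeCheckpointer()
  return ckptr.restore(
      ckpt_path,  # absolute path, hence the os.path.abspath line added above
      args=ocp.args.PyTreeRestore(item=abstract_params, restore_args=restore_args),
  )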

src/maxdiffusion/checkpointing/flux_checkpointer.py

Lines changed: 80 additions & 8 deletions
@@ -17,6 +17,8 @@
 from abc import ABC
 from contextlib import nullcontext
 import functools
+import json
+import os
 import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
@@ -59,8 +61,10 @@ def __init__(self, config, checkpoint_type):
     self.mesh = Mesh(self.devices_array, self.config.mesh_axes)
     self.total_train_batch_size = self.config.total_train_batch_size
 
+    checkpoint_dir = os.path.abspath(self.config.checkpoint_dir)
+
     self.checkpoint_manager = create_orbax_checkpoint_manager(
-        self.config.checkpoint_dir,
+        checkpoint_dir,
         enable_checkpointing=True,
         save_interval_steps=1,
         checkpoint_type=checkpoint_type,
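
The two os.path.abspath additions in this commit share one motivation: Orbax and its TensorStore backend expect absolute checkpoint directories (an assumption inferred from the commit, consistent with Orbax's documented behavior), so a relative path in the config has to be expanded before the manager is created. A trivial illustration:

import os

print(os.path.abspath("ckpts"))  # e.g. /home/user/maxdiffusion/ckpts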
@@ -117,7 +121,7 @@ def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=F
         config=self.config,
         mesh=self.mesh,
         weights_init_fn=weights_init_fn,
-        model_params=params,
+        model_params=params.get("flux_vae", None),
         checkpoint_manager=self.checkpoint_manager,
         checkpoint_item=checkpoint_item_name,
         training=is_training,
@@ -149,20 +153,35 @@ def save_checkpoint(self, train_step, pipeline, train_states):
     def config_to_json(model_or_config):
       return json.loads(model_or_config.to_json_string())
     items = {
-        "config": ocp.args.JsonSave({"model_name": self.config.model_name}),
+        "flux_config": ocp.args.JsonSave(config_to_json(pipeline.flux)),
+        "vae_config": ocp.args.JsonSave(config_to_json(pipeline.vae)),
+        "scheduler_config": ocp.args.JsonSave(config_to_json(pipeline.scheduler)),
     }
 
     items[FLUX_STATE_KEY] = ocp.args.PyTreeSave(train_states[FLUX_STATE_KEY])
+    items["vae_state"] = ocp.args.PyTreeSave(train_states["vae_state"])
+    items["scheduler"] = ocp.args.PyTreeSave(train_states["scheduler"])
 
     self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
 
   def load_params(self, step=None):
 
     self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX
 
-  def load_flux_configs_from_orbax(self):
-    # TODO - load configs from orbax
-    return None
+  def load_flux_configs_from_orbax(self, step):
+    max_logging.log("Restoring stable diffusion configs")
+    if step is None:
+      step = self.checkpoint_manager.latest_step()
+      if step is None:
+        return None
+
+    restore_args = {
+        "flux_config": ocp.args.JsonRestore(),
+        "vae_config": ocp.args.JsonRestore(),
+        "scheduler_config": ocp.args.JsonRestore(),
+    }
+
+    return (self.checkpoint_manager.restore(step, args=ocp.args.Composite(**restore_args)), None)
 
   def load_diffusers_checkpoint(self):
     flash_block_sizes = max_utils.get_flash_block_sizes(self.config)
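
A minimal sketch of the composite-save pattern the new save_checkpoint follows (the helper below and its argument shapes are assumptions): JSON items carry model configs so a later load can rebuild models via from_config, PyTree items carry weights, and a single save() writes them all under one step. The item names must line up with the FLUX_CHECKPOINT tuple registered in checkpointing_utils.py above.

import orbax.checkpoint as ocp

def save_flux_step(manager, step, flux_state, vae_state, scheduler_state, configs):
  items = {
      # Configs as JSON, restorable without instantiating any model.
      "flux_config": ocp.args.JsonSave(configs["flux"]),
      "vae_config": ocp.args.JsonSave(configs["vae"]),
      "scheduler_config": ocp.args.JsonSave(configs["scheduler"]),
      # Weights and scheduler state as PyTrees.
      "flux_state": ocp.args.PyTreeSave(flux_state),
      "vae_state": ocp.args.PyTreeSave(vae_state),
      "scheduler": ocp.args.PyTreeSave(scheduler_state),
  }
  manager.save(step, args=ocp.args.Composite(**items))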
@@ -238,12 +257,65 @@ def load_diffusers_checkpoint(self):
 
   def load_checkpoint(self, step=None, scheduler_class=None):
 
-    model_configs = self.load_flux_configs_from_orbax()
+    model_configs = self.load_flux_configs_from_orbax(step)
 
     pipeline, params = None, {}
 
     if model_configs:
-      print("TODO - load configs from orbax")
+      if jax.device_count() == jax.local_device_count():
+        context = jax.default_device(jax.devices("cpu")[0])
+      else:
+        context = nullcontext()
+
+      with context:
+        clip_encoder = FlaxCLIPTextModel.from_pretrained(
+            self.config.clip_model_name_or_path, dtype=self.config.weights_dtype
+        )
+        clip_tokenizer = CLIPTokenizer.from_pretrained(
+            self.config.clip_model_name_or_path,
+            max_length=77,
+            use_fast=True,
+        )
+        t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
+        t5_tokenizer = AutoTokenizer.from_pretrained(
+            self.config.t5xxl_model_name_or_path,
+            max_length=self.config.max_sequence_length,
+            use_fast=True,
+        )
+
+        vae = FlaxAutoencoderKL.from_config(
+            model_configs[0]["vae_config"],
+            dtype=self.config.activations_dtype,
+            weights_dtype=self.config.weights_dtype,
+            from_pt=self.config.from_pt,
+        )
+
+        transformer = FluxTransformer2DModel.from_config(
+            model_configs[0]["flux_config"],
+            mesh=self.mesh,
+            split_head_dim=self.config.split_head_dim,
+            attention_kernel=self.config.attention,
+            flash_block_sizes=max_utils.get_flash_block_sizes(self.config),
+            dtype=self.config.activations_dtype,
+            weights_dtype=self.config.weights_dtype,
+            precision=max_utils.get_precision(self.config),
+            from_pt=self.config.from_pt,
+        )
+
+        pipeline = FluxPipeline(
+            t5_encoder,
+            clip_encoder,
+            vae,
+            t5_tokenizer,
+            clip_tokenizer,
+            transformer,
+            None,
+            dtype=self.config.activations_dtype,
+            mesh=self.mesh,
+            config=self.config,
+            rng=self.rng,
+        )
+
     else:
       pipeline, params = self.load_diffusers_checkpoint()

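One detail of the new load_checkpoint worth isolating is the device-count guard: on a single host (every JAX device is local) the from_pretrained weights are staged on CPU so a lone accelerator isn't exhausted before sharding, while on multihost setups the default placement is kept. A standalone sketch of that guard, under those assumptions:

import jax
from contextlib import nullcontext

def weight_loading_context():
  # Single host: stage large weights on CPU first.
  if jax.device_count() == jax.local_device_count():
    return jax.default_device(jax.devices("cpu")[0])
  # Multihost: leave placement alone.
  return nullcontext()

with weight_loading_context():
  x = jax.numpy.ones((2, 2))  # stands in for a from_pretrained(...) call
  print(x.devices())          # single host: a CPU device, e.g. {CpuDevice(id=0)}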
Lines changed: 127 additions & 0 deletions

@@ -0,0 +1,127 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from typing import Sequence
+from absl import app
+from contextlib import ExitStack
+import functools
+import time
+import numpy as np
+from PIL import Image
+import jax
+
+from maxdiffusion import pyconfig, max_logging, max_utils
+
+from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer
+from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path
+from maxdiffusion.max_utils import setup_initial_state
+
+
+def run(config):
+  checkpoint_loader = FluxCheckpointer(config, "FLUX_CHECKPOINT")
+  pipeline, params = checkpoint_loader.load_checkpoint()
+
+  if not params:
+    ## VAE
+    weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
+    unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
+        pipeline.vae, None, config, checkpoint_loader.mesh, weights_init_fn, False
+    )
+    # load vae params from orbax checkpoint
+    vae_params = load_params_from_path(
+        config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "vae_state"
+    )
+    vae_state = {"params": vae_params}
+
+    ## Flux
+    weights_init_fn = functools.partial(pipeline.flux.init_weights,
+                                        rngs=checkpoint_loader.rng,
+                                        max_sequence_length=config.max_sequence_length)
+
+    unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
+        pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False
+    )
+    # load flux params from orbax checkpoint
+    flux_params = load_params_from_path(
+        config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "flux_state"
+    )
+    flux_state = {"params": flux_params}
+  else:
+    weights_init_fn = functools.partial(
+        pipeline.flux.init_weights,
+        rngs=checkpoint_loader.rng,
+        max_sequence_length=config.max_sequence_length,
+        eval_only=False,
+    )
+    transformer_state, flux_state_shardings = setup_initial_state(
+        model=pipeline.flux,
+        tx=None,
+        config=config,
+        mesh=checkpoint_loader.mesh,
+        weights_init_fn=weights_init_fn,
+        model_params=None,
+        training=False,
+    )
+    transformer_state = transformer_state.replace(params=params["flux_transformer_params"])
+    transformer_state = jax.device_put(transformer_state, flux_state_shardings)
+
+    weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
+    vae_state, _ = setup_initial_state(
+        model=pipeline.vae,
+        tx=None,
+        config=config,
+        mesh=checkpoint_loader.mesh,
+        weights_init_fn=weights_init_fn,
+        model_params=params["flux_vae"],
+        training=False,
+    )
+
+    vae_state = {"params": vae_state.params}
+    flux_state = {"params": transformer_state.params}
+
+  t0 = time.perf_counter()
+  with ExitStack() as stack:
+    imgs = pipeline(flux_params=flux_state,
+                    timesteps=50,
+                    vae_params=vae_state).block_until_ready()
+  t1 = time.perf_counter()
+  max_logging.log(f"Compile time: {t1 - t0:.1f}s.")
+
+  t0 = time.perf_counter()
+  with ExitStack() as stack:
+    imgs = pipeline(flux_params=flux_state,
+                    timesteps=50,
+                    vae_params=vae_state).block_until_ready()
+  imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
+  t1 = time.perf_counter()
+  max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
+  imgs = np.array(imgs)
+  imgs = (imgs * 0.5 + 0.5).clip(0, 1)
+  imgs = np.transpose(imgs, (0, 2, 3, 1))
+  imgs = np.uint8(imgs * 255)
+  for i, image in enumerate(imgs):
+    Image.fromarray(image).save(f"flux_{i}.png")
+
+  return imgs
+
+
+def main(argv: Sequence[str]) -> None:
+  pyconfig.initialize(argv)
+  run(pyconfig.config)
+
+
+if __name__ == "__main__":
+  app.run(main)
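
The post-processing at the end of run() is standard VAE-output denormalization; a worked example of those four lines on a toy array (not real model output):

import numpy as np

imgs = np.array([[[[-1.0, 0.0], [0.5, 1.0]]]])  # toy NCHW batch: 1 image, 1 channel, 2x2
imgs = (imgs * 0.5 + 0.5).clip(0, 1)            # VAE range [-1, 1] -> [0, 1]
imgs = np.transpose(imgs, (0, 2, 3, 1))         # NCHW -> NHWC for PIL
imgs = np.uint8(imgs * 255)                     # quantize to 8 bits
print(imgs.squeeze())                           # [[  0 127] [191 255]]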

src/maxdiffusion/pipelines/flux/flux_pipeline.py

Lines changed: 3 additions & 3 deletions
@@ -102,12 +102,12 @@ def unpack(self, x: Array, height: int, width: int) -> Array:
   def vae_decode(self, latents, vae, state, config):
     img = self.unpack(x=latents, height=config.resolution, width=config.resolution)
     img = img / vae.config.scaling_factor + vae.config.shift_factor
-    img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample
+    img = vae.apply({"params": state["params"]}, img, deterministic=True, method=vae.decode).sample
     return img
 
   def vae_encode(self, latents, vae, state):
     img = vae.apply(
-        {"params": state.params},
+        {"params": state["params"]},
         latents,
         deterministic=True,
         method=vae.encode).latent_dist.mode()
@@ -297,7 +297,7 @@ def loop_body(
     t_prev = p_ts[step]
     t_vec = jnp.full((latents.shape[0],), t_curr, dtype=latents.dtype)
     pred = transformer.apply(
-        {"params": state.params},
+        {"params": state["params"]},
         hidden_states=latents,
         img_ids=latent_image_ids,
         encoder_hidden_states=prompt_embeds,
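
The three edits above are one calling-convention change: the pipeline now indexes state["params"] instead of reading a TrainState attribute, so callers can pass either a plain dict (as the new inference file does) or params unwrapped from a TrainState. A small sketch of both callers under that assumption:

from flax.training import train_state
import optax

# Trainer side: wrap TrainState params in a plain dict.
ts = train_state.TrainState.create(apply_fn=None, params={"w": 1.0}, tx=optax.identity())
flux_state = {"params": ts.params}

# Inference side: restored params are already a plain pytree.
flux_state_restored = {"params": {"w": 1.0}}

assert flux_state["params"] == flux_state_restored["params"]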

src/maxdiffusion/trainers/flux_trainer.py

Lines changed: 7 additions & 20 deletions
@@ -65,15 +65,7 @@ def __init__(self, config):
       raise ValueError("this script currently doesn't support training text_encoders")
 
   def post_training_steps(self, pipeline, params, train_states, msg=""):
-    imgs = pipeline(flux_params=train_states[FLUX_STATE_KEY],
-                    timesteps=50,
-                    vae_params=train_states["vae_state"])
-    imgs = np.array(imgs)
-    imgs = (imgs * 0.5 + 0.5).clip(0, 1)
-    imgs = np.transpose(imgs, (0, 2, 3, 1))
-    imgs = np.uint8(imgs * 255)
-    for i, image in enumerate(imgs):
-      Image.fromarray(image).save(f"flux_{msg}_{i}.png")
+    pass
 
   def create_scheduler(self, pipeline, params):
     noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained(
@@ -113,7 +105,7 @@ def start_training(self):
 
 
     vae_state, vae_state_mesh_shardings = self.create_vae_state(
-        pipeline=pipeline, params=params[FLUX_VAE_PARAMS_KEY], checkpoint_item_name=VAE_STATE_KEY, is_training=False
+        pipeline=pipeline, params=params, checkpoint_item_name=VAE_STATE_KEY, is_training=False
     )
     train_states[VAE_STATE_KEY] = vae_state
     state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings
@@ -131,14 +123,13 @@ def start_training(self):
     flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state(
         # ambiguous here, but if params=None
         # Then it's 1 of 2 scenarios:
-        # 1. unet state will be loaded directly from orbax
-        # 2. a new unet is being trained from scratch.
+        # 1. flux state will be loaded directly from orbax
+        # 2. a new flux is being trained from scratch.
         pipeline=pipeline,
         params=None,  # Params are loaded inside create_flux_state
         checkpoint_item_name=FLUX_STATE_KEY,
         is_training=True,
     )
-    flux_state = flux_state.replace(params=params[FLUX_TRANSFORMER_PARAMS_KEY])
     flux_state = jax.device_put(flux_state, flux_state_mesh_shardings)
     train_states[FLUX_STATE_KEY] = flux_state
     state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings
@@ -162,7 +153,7 @@ def start_training(self):
     )
     # 6. save final checkpoint
     # Hook
-    #self.post_training_steps(pipeline, params, train_states, "after_training")
+    self.post_training_steps(pipeline, params, train_states, "after_training")
 
   def get_shaped_batch(self, config, pipeline=None):
     """Return the shape of the batch - this is what eval_shape would return for the
@@ -408,13 +399,9 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator):
       if self.config.enable_profiler and step == last_profiling_step:
         max_utils.deactivate_profiler(self.config)
 
-      if self.config.write_metrics:
-        write_metrics(
-            writer, local_metrics_file, running_gcs_metrics, train_metric, self.config.max_train_steps - 1, self.config
-        )
-
     train_states[FLUX_STATE_KEY] = flux_state
-    max_logging.log(f"Average time per step: {sum(times[2:], datetime.timedelta(0)) / len(times[2:])}")
+    if len(times) > 0:
+      max_logging.log(f"Average time per step: {sum(times[2:], datetime.timedelta(0)) / len(times[2:])}")
     if self.config.save_final_checkpoint:
       max_logging.log(f"Saving checkpoint for step {step}")
       self.save_checkpoint(step, pipeline, train_states)
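
One hedged caveat on the new timing guard (an observation, not part of the commit): times[2:] drops the first two compile-heavy steps, so it is empty for runs of one or two steps and the division would still fail; a stricter guard checks for more than two entries.

import datetime

times = [datetime.timedelta(seconds=s) for s in (9.0, 8.5, 1.0, 1.2)]
if len(times) > 2:  # stricter than the len(times) > 0 added above
  avg = sum(times[2:], datetime.timedelta(0)) / len(times[2:])
  print(f"Average time per step: {avg}")  # 0:00:01.100000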
