Skip to content

Commit c11ad9d

Browse files
committed
Added training code, loss and results are stable
1 parent 296e956 commit c11ad9d

11 files changed

Lines changed: 1190 additions & 20 deletions

File tree

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
3535
STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
36+
FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
3637

3738

3839
def create_orbax_checkpoint_manager(
@@ -66,7 +67,7 @@ def create_orbax_checkpoint_manager(
6667
"text_encoder_state",
6768
"tokenizer_config",
6869
)
69-
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
70+
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
7071
item_names += (
7172
"text_encoder_2_state",
7273
"text_encoder_2_config",
@@ -117,7 +118,7 @@ def load_stable_diffusion_configs(
117118
"tokenizer_config": orbax.checkpoint.args.JsonRestore(),
118119
}
119120

120-
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
121+
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
121122
restore_args["text_encoder_2_config"] = orbax.checkpoint.args.JsonRestore()
122123

123124
return (checkpoint_manager.restore(step, args=orbax.checkpoint.args.Composite(**restore_args)), None)
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
"""
2+
Copyright 2024 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from abc import ABC
18+
from contextlib import nullcontext
19+
import os
20+
import json
21+
import functools
22+
import jax
23+
import jax.numpy as jnp
24+
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
25+
import orbax.checkpoint as ocp
26+
import grain.python as grain
27+
from maxdiffusion import (
28+
max_utils,
29+
FlaxAutoencoderKL,
30+
max_logging,
31+
)
32+
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
33+
from ..pipelines.flux.flux_pipeline import FluxPipeline
34+
35+
from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)
36+
37+
from maxdiffusion.checkpointing.checkpointing_utils import (
38+
create_orbax_checkpoint_manager,
39+
load_stable_diffusion_configs,
40+
)
41+
from maxdiffusion.models.flux.util import load_flow_model
42+
43+
# Checkpoint-type tag passed to the Orbax checkpoint manager for Flux models.
# NOTE(review): also defined in checkpointing_utils.py — presumably kept in
# sync by hand; confirm they stay identical.
FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
# Supported on-disk checkpoint formats for load/save dispatch.
_CHECKPOINT_FORMAT_DIFFUSERS = "CHECKPOINT_FORMAT_DIFFUSERS"
_CHECKPOINT_FORMAT_ORBAX = "CHECKPOINT_FORMAT_ORBAX"
48+
class FluxCheckpointer(ABC):
  """Checkpointing helper for Flux training.

  Builds the device mesh, creates sharded train states for the transformer
  and VAE, saves/restores Orbax checkpoints, and assembles the full
  FluxPipeline (text encoders, tokenizers, VAE, transformer).

  NOTE(review): inherits from ABC but declares no abstract methods, so the
  class is directly instantiable — confirm whether subclassing is intended.
  """

  def __init__(self, config, checkpoint_type):
    # Training config object; attributes (seed, mesh_axes, checkpoint_dir,
    # dataset_type, learning rates, model paths, ...) are read throughout.
    self.config = config
    self.checkpoint_type = checkpoint_type
    # Set later via _set_checkpoint_format() / load_params().
    self.checkpoint_format = None

    self.rng = jax.random.PRNGKey(self.config.seed)
    self.devices_array = max_utils.create_device_mesh(config)
    self.mesh = Mesh(self.devices_array, self.config.mesh_axes)
    self.total_train_batch_size = self.config.total_train_batch_size

    # Orbax manager; save_interval_steps=1 means every requested step is
    # eligible for saving.
    self.checkpoint_manager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, config, learning_rate):
    """Create the optimizer and its learning-rate schedule.

    Returns:
      (tx, learning_rate_scheduler): the gradient transformation and the
      schedule built from config's schedule/warmup/step settings.
    """
    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return tx, learning_rate_scheduler

  def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training):
    """Create the (sharded) train state for the Flux transformer.

    Args:
      pipeline: FluxPipeline holding the transformer at `pipeline.flux`.
      params: unused here — weights come from init/load below; kept for a
        signature parallel to create_vae_state.
      checkpoint_item_name: Orbax item name for this state.
      is_training: when True an optimizer is attached.

    Returns:
      (flux_state, state_mesh_shardings, learning_rate_scheduler); the
      scheduler is None when is_training is False.
    """
    transformer = pipeline.flux

    tx, learning_rate_scheduler = None, None
    if is_training:
      learning_rate = self.config.learning_rate

      tx, learning_rate_scheduler = self._create_optimizer(self.config, learning_rate)

    # Abstract/eval-only init used as the parameter tree skeleton for loading
    # the pretrained flow-model weights on host ("cpu") below.
    transformer_eval_params = transformer.init_weights(
        rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
    )

    transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")

    weights_init_fn = functools.partial(pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length)
    flux_state, state_mesh_shardings = max_utils.setup_initial_state(
        model=pipeline.flux,
        tx=tx,
        config=self.config,
        mesh=self.mesh,
        weights_init_fn=weights_init_fn,
        model_params=None,
        checkpoint_manager=self.checkpoint_manager,
        checkpoint_item=checkpoint_item_name,
        training=is_training,
    )
    # When fine-tuning (not training from scratch), swap in the pretrained
    # weights loaded above.
    if not self.config.train_new_flux:
      flux_state = flux_state.replace(params=transformer_params)
    # Place the (possibly replaced) state onto devices per the shardings.
    flux_state = jax.device_put(flux_state, state_mesh_shardings)
    return flux_state, state_mesh_shardings, learning_rate_scheduler

  def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):
    """Create the (sharded) state for the VAE from given params.

    Currently VAE training is not supported, so no optimizer is attached.

    Returns:
      Whatever max_utils.setup_initial_state returns — presumably
      (state, state_mesh_shardings); confirm against max_utils.
    """
    weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng)
    return max_utils.setup_initial_state(
        model=pipeline.vae,
        tx=None,
        config=self.config,
        mesh=self.mesh,
        weights_init_fn=weights_init_fn,
        model_params=params,
        checkpoint_manager=self.checkpoint_manager,
        checkpoint_item=checkpoint_item_name,
        training=is_training,
    )

  def restore_data_iterator_state(self, data_iterator):
    """Restore a Grain data-iterator from the latest checkpoint, if present.

    Only applies when dataset_type == "grain" and the latest checkpoint
    directory contains an "iter" item; otherwise the iterator is returned
    unchanged.
    """
    if (
        self.config.dataset_type == "grain"
        and data_iterator is not None
        and (self.checkpoint_manager.directory / str(self.checkpoint_manager.latest_step()) / "iter").exists()
    ):
      max_logging.log("Restoring data iterator from checkpoint")
      restored = self.checkpoint_manager.restore(
          self.checkpoint_manager.latest_step(),
          args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)),
      )
      data_iterator.local_iterator = restored["iter"]
    else:
      max_logging.log("data iterator checkpoint not found")
    return data_iterator

  def _get_pipeline_class(self):
    # Pipeline class used by this checkpointer.
    return FluxPipeline

  def _set_checkpoint_format(self, checkpoint_format):
    # Record which on-disk format (_CHECKPOINT_FORMAT_*) is in use.
    self.checkpoint_format = checkpoint_format

  def save_checkpoint(self, train_step, pipeline, train_states):
    """Save the flux train state (plus a small JSON config) at train_step.

    Only "flux_state" from train_states is persisted; text encoders and VAE
    are not saved here.
    """
    items = {
        "config": ocp.args.JsonSave({"model_name": self.config.model_name}),
    }

    items["flux_state"] = ocp.args.PyTreeSave(train_states["flux_state"])

    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

  def load_params(self, step=None):
    # NOTE(review): only marks the format as Orbax; no params are actually
    # loaded here — presumably a stub to satisfy a shared checkpointer
    # interface. Confirm callers do not expect a return value.
    self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX

  def load_checkpoint(self, step=None, scheduler_class=None):
    """Assemble and return (FluxPipeline, vae_params) from pretrained parts.

    Loads CLIP and T5 encoders/tokenizers from Hugging Face paths in config,
    casts encoder params to bfloat16 and replicates them across devices,
    loads the VAE from the diffusers checkpoint, and builds the transformer
    from config only (weights are loaded later in create_flux_state).

    `step` and `scheduler_class` are currently unused.
    """
    clip_encoder = FlaxCLIPTextModel.from_pretrained(
        self.config.clip_model_name_or_path, dtype=self.config.weights_dtype
    )
    clip_tokenizer = CLIPTokenizer.from_pretrained(
        self.config.clip_model_name_or_path, max_length=77, use_fast=True
    )

    t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
    t5_tokenizer = AutoTokenizer.from_pretrained(
        self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
    )
    # Replicate (not shard) both text encoders on every device, in bfloat16.
    encoders_sharding = PositionalSharding(self.devices_array).replicate()
    partial_device_put_replicated = functools.partial(max_utils.device_put_replicated, sharding=encoders_sharding)
    clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_encoder.params)
    clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_encoder.params)
    t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params)
    t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)

    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        self.config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16"
    )

    flash_block_sizes = max_utils.get_flash_block_sizes(self.config)
    # loading from pretrained here causes a crash when trying to compile the model
    # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu
    transformer = FluxTransformer2DModel.from_config(
        self.config.pretrained_model_name_or_path,
        subfolder="transformer",
        mesh=self.mesh,
        split_head_dim=self.config.split_head_dim,
        attention_kernel=self.config.attention,
        flash_block_sizes=flash_block_sizes,
        dtype=self.config.activations_dtype,
        weights_dtype=self.config.weights_dtype,
        precision=max_utils.get_precision(self.config),
    )

    # Scheduler slot is None here; callers attach one later if needed.
    return FluxPipeline(t5_encoder,
        clip_encoder,
        vae,
        t5_tokenizer,
        clip_tokenizer,
        transformer,
        None,
        dtype=self.config.activations_dtype,
        mesh=self.mesh,
        config=self.config,
        rng=self.rng), vae_params

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ norm_num_groups: 32
7373

7474
# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
7575
# else they will be loaded from pretrained_model_name_or_path
76-
train_new_unet: False
76+
train_new_flux: False
7777

7878
# train text_encoder - Currently not supported for SDXL
7979
train_text_encoder: False
@@ -111,7 +111,7 @@ diffusion_scheduler_config: {
111111
base_output_directory: ""
112112

113113
# Hardware
114-
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
114+
hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu'
115115

116116
# Parallelism
117117
mesh_axes: ['data', 'fsdp', 'tensor']
@@ -173,7 +173,7 @@ hf_train_files: ''
173173
hf_access_token: ''
174174
image_column: 'image'
175175
caption_column: 'text'
176-
resolution: 1024
176+
resolution: 512
177177
center_crop: False
178178
random_flip: False
179179
# If cache_latents_text_encoder_outputs is True
@@ -189,17 +189,17 @@ checkpoint_every: -1
189189
enable_single_replica_ckpt_restoring: False
190190

191191
# Training loop
192-
learning_rate: 4.e-7
192+
learning_rate: 1.e-5
193193
scale_lr: False
194194
max_train_samples: -1
195195
# max_train_steps takes priority over num_train_epochs.
196-
max_train_steps: 200
196+
max_train_steps: 1500
197197
num_train_epochs: 1
198198
seed: 0
199199
output_dir: 'sdxl-model-finetuned'
200200
per_device_batch_size: 1
201201

202-
warmup_steps_fraction: 0.0
202+
warmup_steps_fraction: 0.1
203203
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
204204

205205
# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
@@ -209,7 +209,7 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set
209209
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
210210
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
211211
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
212-
adam_weight_decay: 1.e-2 # AdamW Weight decay
212+
adam_weight_decay: 0 # AdamW Weight decay
213213
max_grad_norm: 1.0
214214

215215
enable_profiler: False
@@ -219,14 +219,15 @@ skip_first_n_steps_for_profiler: 5
219219
profiler_steps: 10
220220

221221
# Generation parameters
222-
prompt: "A magical castle in the middle of a forest, artistic drawing"
223-
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
222+
prompt: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet."
223+
prompt_2: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet."
224224
negative_prompt: "purple, red"
225225
do_classifier_free_guidance: True
226226
guidance_scale: 3.5
227227
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
228228
guidance_rescale: 0.0
229229
num_inference_steps: 50
230+
save_final_checkpoint: False
230231

231232
# SDXL Lightning parameters
232233
lightning_from_pt: True

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ diffusion_scheduler_config: {
119119
base_output_directory: ""
120120

121121
# Hardware
122-
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
122+
hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu'
123123

124124
# Parallelism
125125
mesh_axes: ['data', 'fsdp', 'tensor']
@@ -234,7 +234,7 @@ do_classifier_free_guidance: True
234234
guidance_scale: 0.0
235235
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
236236
guidance_rescale: 0.0
237-
num_inference_steps: 4
237+
num_inference_steps: 50
238238

239239
# SDXL Lightning parameters
240240
lightning_from_pt: True

src/maxdiffusion/maxdiffusion_utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,61 @@ def calculate_unet_tflops(config, pipeline, batch_size, rngs, train):
255255
/ jax.local_device_count()
256256
)
257257

258+
def get_dummy_flux_inputs(config, pipeline, batch_size):
259+
"""Returns randomly initialized flux inputs."""
260+
latents, latents_ids = pipeline.prepare_latents(
261+
batch_size=batch_size,
262+
num_channels_latents=pipeline.flux.in_channels // 4,
263+
height=config.resolution,
264+
width=config.resolution,
265+
vae_scale_factor=pipeline.vae_scale_factor,
266+
dtype=config.activations_dtype,
267+
rng=pipeline.rng
268+
)
269+
guidance_vec = jnp.asarray([config.guidance_scale] * batch_size, dtype=config.activations_dtype)
270+
271+
timesteps = jnp.ones((batch_size,), dtype=config.weights_dtype)
272+
t5_hidden_states_shape = (
273+
batch_size,
274+
config.max_sequence_length,
275+
4096,
276+
)
277+
t5_hidden_states = jnp.zeros(t5_hidden_states_shape, dtype=config.weights_dtype)
278+
t5_ids = jnp.zeros((batch_size, t5_hidden_states.shape[1], 3), dtype=config.weights_dtype)
279+
280+
clip_hidden_states_shape = (
281+
batch_size,
282+
768,
283+
)
284+
clip_hidden_states = jnp.zeros(clip_hidden_states_shape, dtype=config.weights_dtype)
285+
286+
return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states)
287+
288+
289+
def calculate_flux_tflops(config, pipeline, batch_size, rngs, train):
  """Estimate per-device TFLOPs of one Flux transformer call.

  batch_size should be per_device_batch_size * jax.local_device_count() or
  attention's shard_map won't cache the compilation when flash is enabled.

  Returns:
    Model TFLOPs divided by the local device count.
  """
  (hidden_states, timesteps, latents_ids, guidance_vec,
   t5_hidden_states, t5_ids, clip_hidden_states) = get_dummy_flux_inputs(config, pipeline, batch_size)

  total_tflops = max_utils.calculate_model_tflops(
      pipeline.flux,
      rngs,
      train,
      hidden_states=hidden_states,
      img_ids=latents_ids,
      encoder_hidden_states=t5_hidden_states,
      txt_ids=t5_ids,
      pooled_projections=clip_hidden_states,
      timestep=timesteps,
      guidance=guidance_vec,
  )
  return total_tflops / jax.local_device_count()
312+
258313

259314
def tokenize_captions(examples, caption_column, tokenizer, input_ids_key="input_ids", p_encode=None):
260315
"""Tokenize captions for sd1.x,sd2.x models."""
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Lazy-import table for this package. The original entry pointed at
# "pipeline_jflux"/"JfluxPipeline", which does not match the module and class
# actually defined (and imported) here — fixed to the real names.
_import_structure = {"flux_pipeline": "FluxPipeline"}

from .flux_pipeline import (
    FluxPipeline,
)

0 commit comments

Comments
 (0)