Skip to content

Commit ba2d028

Browse files
committed
Added training code, loss and results are stable
1 parent a774fb1 commit ba2d028

11 files changed

Lines changed: 1192 additions & 22 deletions

File tree

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
3535
STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
36+
FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
3637

3738

3839
def create_orbax_checkpoint_manager(
@@ -66,7 +67,7 @@ def create_orbax_checkpoint_manager(
6667
"text_encoder_state",
6768
"tokenizer_config",
6869
)
69-
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
70+
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
7071
item_names += (
7172
"text_encoder_2_state",
7273
"text_encoder_2_config",
@@ -117,7 +118,7 @@ def load_stable_diffusion_configs(
117118
"tokenizer_config": orbax.checkpoint.args.JsonRestore(),
118119
}
119120

120-
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
121+
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
121122
restore_args["text_encoder_2_config"] = orbax.checkpoint.args.JsonRestore()
122123

123124
return (checkpoint_manager.restore(step, args=orbax.checkpoint.args.Composite(**restore_args)), None)
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
"""
2+
Copyright 2024 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from abc import ABC
18+
from contextlib import nullcontext
19+
import os
20+
import json
21+
import functools
22+
import jax
23+
import jax.numpy as jnp
24+
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
25+
import orbax.checkpoint as ocp
26+
import grain.python as grain
27+
from maxdiffusion import (
28+
max_utils,
29+
FlaxAutoencoderKL,
30+
max_logging,
31+
)
32+
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
33+
from ..pipelines.flux.flux_pipeline import FluxPipeline
34+
35+
from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)
36+
37+
from maxdiffusion.checkpointing.checkpointing_utils import (
38+
create_orbax_checkpoint_manager,
39+
load_stable_diffusion_configs,
40+
)
41+
from maxdiffusion.models.flux.util import load_flow_model
42+
43+
FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
44+
_CHECKPOINT_FORMAT_DIFFUSERS = "CHECKPOINT_FORMAT_DIFFUSERS"
45+
_CHECKPOINT_FORMAT_ORBAX = "CHECKPOINT_FORMAT_ORBAX"
46+
47+
48+
class FluxCheckpointer(ABC):
  """Checkpointing helper for Flux training/inference.

  Owns the device mesh, the PRNG key and an Orbax ``CheckpointManager``,
  and knows how to build the Flux pipeline (text encoders, VAE,
  transformer) plus the sharded train states for it.
  """

  def __init__(self, config, checkpoint_type):
    """Create the mesh and the Orbax checkpoint manager.

    Args:
      config: maxdiffusion config object (seed, mesh_axes, checkpoint_dir, ...).
      checkpoint_type: one of the *_CHECKPOINT constants, e.g. FLUX_CHECKPOINT.
    """
    self.config = config
    self.checkpoint_type = checkpoint_type
    # Set lazily via _set_checkpoint_format / load_params.
    self.checkpoint_format = None

    self.rng = jax.random.PRNGKey(self.config.seed)
    self.devices_array = max_utils.create_device_mesh(config)
    self.mesh = Mesh(self.devices_array, self.config.mesh_axes)
    self.total_train_batch_size = self.config.total_train_batch_size

    self.checkpoint_manager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, config, learning_rate):
    """Build the optimizer and its LR schedule.

    Returns:
      (tx, learning_rate_scheduler) — optax transformation and schedule fn.
    """
    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return tx, learning_rate_scheduler

  def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training):
    """Create the (sharded) train state for the Flux transformer.

    Args:
      pipeline: FluxPipeline holding the transformer as ``pipeline.flux``.
      params: unused here (kept for signature parity with create_vae_state).
      checkpoint_item_name: item name under which the state is checkpointed.
      is_training: when True an optimizer is attached to the state.

    Returns:
      (flux_state, state_mesh_shardings, learning_rate_scheduler); the
      scheduler is None when not training.
    """
    transformer = pipeline.flux

    tx, learning_rate_scheduler = None, None
    if is_training:
      tx, learning_rate_scheduler = self._create_optimizer(self.config, self.config.learning_rate)

    # Only pay for the pretrained-weight load when we are actually going to
    # use it; previously init_weights + load_flow_model ran unconditionally
    # and the result was thrown away when train_new_flux was set.
    transformer_params = None
    if not self.config.train_new_flux:
      # eval_only init produces an abstract param tree used as the template
      # for loading the pretrained flow weights on host ("cpu").
      transformer_eval_params = transformer.init_weights(
          rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
      )
      transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")

    weights_init_fn = functools.partial(
        pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length
    )
    flux_state, state_mesh_shardings = max_utils.setup_initial_state(
        model=pipeline.flux,
        tx=tx,
        config=self.config,
        mesh=self.mesh,
        weights_init_fn=weights_init_fn,
        model_params=None,
        checkpoint_manager=self.checkpoint_manager,
        checkpoint_item=checkpoint_item_name,
        training=is_training,
    )
    if transformer_params is not None:
      # Replace the randomly initialized params with the pretrained ones.
      flux_state = flux_state.replace(params=transformer_params)
    flux_state = jax.device_put(flux_state, state_mesh_shardings)
    return flux_state, state_mesh_shardings, learning_rate_scheduler

  def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):
    """Create the (sharded) state for the VAE.

    Currently VAE training is not supported, so no optimizer is attached.
    """
    weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng)
    return max_utils.setup_initial_state(
        model=pipeline.vae,
        tx=None,
        config=self.config,
        mesh=self.mesh,
        weights_init_fn=weights_init_fn,
        model_params=params,
        checkpoint_manager=self.checkpoint_manager,
        checkpoint_item=checkpoint_item_name,
        training=is_training,
    )

  def restore_data_iterator_state(self, data_iterator):
    """Restore a Grain data-iterator from the latest checkpoint, if present.

    Only applies to grain datasets; otherwise (or when no "iter" item exists
    in the latest checkpoint) the iterator is returned unchanged.
    """
    if (
        self.config.dataset_type == "grain"
        and data_iterator is not None
        and (self.checkpoint_manager.directory / str(self.checkpoint_manager.latest_step()) / "iter").exists()
    ):
      max_logging.log("Restoring data iterator from checkpoint")
      restored = self.checkpoint_manager.restore(
          self.checkpoint_manager.latest_step(),
          args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)),
      )
      data_iterator.local_iterator = restored["iter"]
    else:
      max_logging.log("data iterator checkpoint not found")
    return data_iterator

  def _get_pipeline_class(self):
    """Return the pipeline class this checkpointer builds."""
    return FluxPipeline

  def _set_checkpoint_format(self, checkpoint_format):
    """Record the checkpoint format (_CHECKPOINT_FORMAT_* constant)."""
    self.checkpoint_format = checkpoint_format

  def save_checkpoint(self, train_step, pipeline, train_states):
    """Save the flux train state (plus a small config item) at train_step.

    Args:
      train_step: step number used as the checkpoint key.
      pipeline: unused; kept for interface symmetry with other checkpointers.
      train_states: dict containing "flux_state".
    """
    items = {
        "config": ocp.args.JsonSave({"model_name": self.config.model_name}),
        "flux_state": ocp.args.PyTreeSave(train_states["flux_state"]),
    }
    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

  def load_params(self, step=None):
    """Mark the checkpoint format as Orbax (no weights are loaded here)."""
    self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX

  def load_checkpoint(self, step=None, scheduler_class=None):
    """Build the full Flux pipeline from pretrained components.

    Note: ``step`` and ``scheduler_class`` are currently unused — the
    components are loaded from the configured pretrained paths, not from an
    Orbax step.

    Returns:
      (FluxPipeline, vae_params)
    """
    clip_encoder = FlaxCLIPTextModel.from_pretrained(
        self.config.clip_model_name_or_path, dtype=self.config.weights_dtype
    )
    clip_tokenizer = CLIPTokenizer.from_pretrained(
        self.config.clip_model_name_or_path, max_length=77, use_fast=True
    )

    t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
    t5_tokenizer = AutoTokenizer.from_pretrained(
        self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
    )

    # Both text encoders are cast to bf16 and fully replicated on all devices.
    encoders_sharding = PositionalSharding(self.devices_array).replicate()
    partial_device_put_replicated = functools.partial(max_utils.device_put_replicated, sharding=encoders_sharding)
    clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_encoder.params)
    clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_encoder.params)
    t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params)
    t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)

    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        self.config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16"
    )

    flash_block_sizes = max_utils.get_flash_block_sizes(self.config)
    # loading from pretrained here causes a crash when trying to compile the model
    # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu
    transformer = FluxTransformer2DModel.from_config(
        self.config.pretrained_model_name_or_path,
        subfolder="transformer",
        mesh=self.mesh,
        split_head_dim=self.config.split_head_dim,
        attention_kernel=self.config.attention,
        flash_block_sizes=flash_block_sizes,
        dtype=self.config.activations_dtype,
        weights_dtype=self.config.weights_dtype,
        precision=max_utils.get_precision(self.config),
    )

    return FluxPipeline(
        t5_encoder,
        clip_encoder,
        vae,
        t5_tokenizer,
        clip_tokenizer,
        transformer,
        None,
        dtype=self.config.activations_dtype,
        mesh=self.mesh,
        config=self.config,
        rng=self.rng,
    ), vae_params

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ precision: "DEFAULT"
5454
# Set true to load weights from pytorch
5555
from_pt: True
5656
split_head_dim: True
57-
attention: 'flash' # Supported attention: dot_product, flash
57+
attention: 'dot_product' # Supported attention: dot_product, flash
5858

5959
flash_block_sizes: {}
6060
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
@@ -73,7 +73,7 @@ norm_num_groups: 32
7373

7474
# If train_new_flux, flux transformer weights will be randomly initialized to train the model from scratch
7575
# else they will be loaded from pretrained_model_name_or_path
76-
train_new_unet: False
76+
train_new_flux: False
7777

7878
# train text_encoder - Currently not supported for SDXL
7979
train_text_encoder: False
@@ -111,7 +111,7 @@ diffusion_scheduler_config: {
111111
base_output_directory: ""
112112

113113
# Hardware
114-
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
114+
hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu'
115115

116116
# Parallelism
117117
mesh_axes: ['data', 'fsdp', 'tensor']
@@ -173,7 +173,7 @@ hf_train_files: ''
173173
hf_access_token: ''
174174
image_column: 'image'
175175
caption_column: 'text'
176-
resolution: 1024
176+
resolution: 512
177177
center_crop: False
178178
random_flip: False
179179
# If cache_latents_text_encoder_outputs is True
@@ -189,17 +189,17 @@ checkpoint_every: -1
189189
enable_single_replica_ckpt_restoring: False
190190

191191
# Training loop
192-
learning_rate: 4.e-7
192+
learning_rate: 1.e-5
193193
scale_lr: False
194194
max_train_samples: -1
195195
# max_train_steps takes priority over num_train_epochs.
196-
max_train_steps: 200
196+
max_train_steps: 1500
197197
num_train_epochs: 1
198198
seed: 0
199199
output_dir: 'sdxl-model-finetuned'
200200
per_device_batch_size: 1
201201

202-
warmup_steps_fraction: 0.0
202+
warmup_steps_fraction: 0.1
203203
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
204204

205205
# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
@@ -209,7 +209,7 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set
209209
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
210210
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
211211
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
212-
adam_weight_decay: 1.e-2 # AdamW Weight decay
212+
adam_weight_decay: 0 # AdamW Weight decay
213213
max_grad_norm: 1.0
214214

215215
enable_profiler: False
@@ -219,14 +219,15 @@ skip_first_n_steps_for_profiler: 5
219219
profiler_steps: 10
220220

221221
# Generation parameters
222-
prompt: "A magical castle in the middle of a forest, artistic drawing"
223-
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
222+
prompt: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet."
223+
prompt_2: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet."
224224
negative_prompt: "purple, red"
225225
do_classifier_free_guidance: True
226226
guidance_scale: 3.5
227227
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
228228
guidance_rescale: 0.0
229229
num_inference_steps: 50
230+
save_final_checkpoint: False
230231

231232
# SDXL Lightning parameters
232233
lightning_from_pt: True

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ precision: "DEFAULT"
5353
# Set true to load weights from pytorch
5454
from_pt: True
5555
split_head_dim: True
56-
attention: 'flash' # Supported attention: dot_product, flash
56+
attention: 'dot_product' # Supported attention: dot_product, flash
5757
flash_block_sizes: {
5858
"block_q" : 256,
5959
"block_kv_compute" : 256,
@@ -119,7 +119,7 @@ diffusion_scheduler_config: {
119119
base_output_directory: ""
120120

121121
# Hardware
122-
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
122+
hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu'
123123

124124
# Parallelism
125125
mesh_axes: ['data', 'fsdp', 'tensor']
@@ -234,7 +234,7 @@ do_classifier_free_guidance: True
234234
guidance_scale: 0.0
235235
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
236236
guidance_rescale: 0.0
237-
num_inference_steps: 4
237+
num_inference_steps: 50
238238

239239
# SDXL Lightning parameters
240240
lightning_from_pt: True

0 commit comments

Comments
 (0)