Commit 1deeca5

Hyper SDXL Lora support (#127)

Adds Hyper SDXL LoRA loading for inference using a Flax interceptor.

Co-authored-by: Juan Acevedo <jfacevedo@google.com>

1 parent: 9238bc9

26 files changed: 1441 additions & 58 deletions

README.md

Lines changed: 12 additions & 2 deletions
````diff
@@ -17,8 +17,8 @@
 [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

 # What's new?
-- **`2024/8/1`**: Orbax is the new default checkpointer for Stable Diffusion 1.X, 2.x. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
+- **`2024/10/22`**: LoRA support for Hyper SDXL.
+- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
 - **`2024/7/20`**: Dreambooth training for Stable Diffusion 1.x,2.x is now supported.

 # Overview
@@ -32,6 +32,7 @@ MaxDiffusion supports
 * Stable Diffusion 2.1 (training and inference)
 * Stable Diffusion XL (training and inference).
 * Stable Diffusion Lightning (inference).
+* Hyper-SD XL LoRA loading (inference).
 * ControlNet inference (Stable Diffusion 1.4 & SDXL).
 * Dreambooth training support for Stable Diffusion 1.x,2.x.
@@ -43,6 +44,7 @@ MaxDiffusion supports
 * [Training](#training)
   * [Dreambooth](#dreambooth)
 * [Inference](#inference)
+  * [Hyper-SD XL LoRA](#hyper-sdxl-lora)
   * [SDXL Lightning](#sdxl-lightning)
 * [ControlNet](#controlnet)
 * [Comparison To Alternatives](#comparison-to-alternatives)
@@ -129,6 +131,14 @@ To generate images, run the following command:
 python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
 ```

+## Hyper SDXL LoRA
+
+Supports loading Hyper-SDXL LoRA models from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD).
+
+```bash
+python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=2 do_classifier_free_guidance=False prompt="a photograph of a cat wearing a hat riding a skateboard in a park." per_device_batch_size=1 pretrained_model_name_or_path="Lykon/AAM_XL_AnimeMix" from_pt=True revision=main diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}'
+```
+
 ## SDXL Lightning

 Single and Multi host inference is supported with sharding annotations:
````
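The `lora_config` and `diffusion_scheduler_config` flags in the command above are JSON strings whose keys hold parallel lists, one entry per adapter. A minimal sketch of how such a flag could be decoded and its lists paired up (the variable names here are illustrative assumptions, not the actual maxdiffusion parser):

```python
import json

# The lora_config flag value, exactly as passed on the command line.
flag = ('{"lora_model_name_or_path": ["ByteDance/Hyper-SD"], '
        '"weight_name": ["Hyper-SDXL-2steps-lora.safetensors"], '
        '"adapter_name": ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}')

cfg = json.loads(flag)

# Zip the parallel lists into one dict per LoRA adapter to load.
adapters = [
    {"repo": repo, "weight": w, "name": n, "scale": s, "from_pt": fp}
    for repo, w, n, s, fp in zip(
        cfg["lora_model_name_or_path"], cfg["weight_name"],
        cfg["adapter_name"], cfg["scale"], cfg["from_pt"])
]
# adapters[0]["name"] == "hyper-sdxl", adapters[0]["scale"] == 0.7
```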

generated_image.png

Binary file removed (−513 KB); not shown.

requirements.txt

Lines changed: 2 additions & 0 deletions
```diff
@@ -26,4 +26,6 @@ git+https://github.com/mlperf/logging.git
 opencv-python-headless==4.10.0.84
 orbax-checkpoint>=0.5.20
 tokenizers==0.20.0
+huggingface_hub==0.24.7
+
 huggingface_hub==0.24.7
```

src/maxdiffusion/configs/base_xl.yml

Lines changed: 22 additions & 5 deletions
```diff
@@ -75,11 +75,10 @@ timestep_bias: {

 # Override parameters from checkpoint's scheduler.
 diffusion_scheduler_config: {
-  _class_name: '',
-  # values are v_prediction or leave empty to use scheduler's default.
-  prediction_type: '',
+  _class_name: 'FlaxEulerDiscreteScheduler',
+  prediction_type: 'epsilon',
   rescale_zero_terminal_snr: False,
-  timestep_spacing: ''
+  timestep_spacing: 'trailing'
 }

 # Output directory
@@ -197,7 +196,7 @@ profiler_steps: 10
 prompt: "A magical castle in the middle of a forest, artistic drawing"
 negative_prompt: "purple, red"
 do_classifier_free_guidance: True
-guidance_scale: 9
+guidance_scale: 9.0
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 20
@@ -209,6 +208,24 @@ lightning_repo: ""
 # Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
 lightning_ckpt: ""

+# LoRA parameters
+# Values are lists to support multiple LoRA loading during inference in the future.
+lora_config: {
+  lora_model_name_or_path: [],
+  weight_name: [],
+  adapter_name: [],
+  scale: [],
+  from_pt: []
+}
+# Example with values:
+# lora_config: {
+#   lora_model_name_or_path: ["ByteDance/Hyper-SD"],
+#   weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
+#   adapter_name: ["hyper-sdxl"],
+#   scale: [0.7],
+#   from_pt: [True]
+# }
+
 enable_mllog: False

 #controlnet
```
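Because the `lora_config` entries are parallel lists, a loader presumably has to verify that they agree in length before iterating over adapters. A hedged sketch of such a check (the function name is an assumption, not part of the maxdiffusion codebase):

```python
def validate_lora_config(lora_config):
  """Check that all parallel lists in lora_config have one entry per adapter."""
  lengths = {k: len(v) for k, v in lora_config.items()}
  if len(set(lengths.values())) > 1:
    raise ValueError(f"lora_config lists must be equal length, got {lengths}")
  # Return the number of adapters configured.
  return lengths.get("lora_model_name_or_path", 0)

n = validate_lora_config({
    "lora_model_name_or_path": ["ByteDance/Hyper-SD"],
    "weight_name": ["Hyper-SDXL-2steps-lora.safetensors"],
    "adapter_name": ["hyper-sdxl"],
    "scale": [0.7],
    "from_pt": [True],
})
# n == 1
```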

src/maxdiffusion/configs/base_xl_lightning.yml

Lines changed: 20 additions & 2 deletions
```diff
@@ -56,7 +56,7 @@ text_encoder_learning_rate: 4.25e-6
 diffusion_scheduler_config: {
   _class_name: 'DDIMScheduler',
   # values are v_prediction or leave empty to use scheduler's default.
-  prediction_type: '',
+  prediction_type: 'epsilon',
   rescale_zero_terminal_snr: False,
   timestep_spacing: 'trailing'
 }
@@ -156,7 +156,7 @@ profiler_steps: 5
 prompt: "portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal, elegant, sharp focus, soft lighting, vibrant colors"
 negative_prompt: "purple, red"
 do_classifier_free_guidance: False
-guidance_scale: 2
+guidance_scale: 2.0
 guidance_rescale: -1
 num_inference_steps: 4

@@ -165,4 +165,22 @@ lightning_from_pt: True
 lightning_repo: "ByteDance/SDXL-Lightning"
 lightning_ckpt: "sdxl_lightning_4step_unet.safetensors"

+# LoRA parameters
+# Values are lists to support multiple LoRA loading during inference in the future.
+lora_config: {
+  lora_model_name_or_path: [],
+  weight_name: [],
+  adapter_name: [],
+  scale: [],
+  from_pt: []
+}
+# Example with values:
+# lora_config: {
+#   lora_model_name_or_path: ["ByteDance/Hyper-SD"],
+#   weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
+#   adapter_name: ["hyper-sdxl"],
+#   scale: [0.7],
+#   from_pt: [True]
+# }
+
 enable_mllog: False
```

src/maxdiffusion/generate_sdxl.py

Lines changed: 28 additions & 27 deletions
```diff
@@ -23,16 +23,18 @@
 import jax
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P
+import flax.linen as nn
 from flax.linen import partitioning as nn_partitioning

-from maxdiffusion import (
-    FlaxEulerDiscreteScheduler,
-)
-
 from maxdiffusion import pyconfig, max_utils
 from maxdiffusion.image_processor import VaeImageProcessor
-from maxdiffusion.maxdiffusion_utils import (get_add_time_ids, rescale_noise_cfg, load_sdxllightning_unet)
+from maxdiffusion.maxdiffusion_utils import (
+    get_add_time_ids,
+    rescale_noise_cfg,
+    load_sdxllightning_unet,
+    maybe_load_lora,
+    create_scheduler,
+)

 from maxdiffusion.trainers.sdxl_trainer import (StableDiffusionXLTrainer)

@@ -82,7 +84,6 @@ def apply_classifier_free_guidance(noise_pred, guidance_scale):
     lambda _: noise_pred,
     operand=None,
   )
-
   latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

   return latents, scheduler_state, state
@@ -217,6 +218,8 @@ def run(config):
   checkpoint_loader = GenerateSDXL(config)
   pipeline, params = checkpoint_loader.load_checkpoint()

+  noise_scheduler, noise_scheduler_state = create_scheduler(pipeline.scheduler.config, config)
+
   weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng)
   unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
     pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False
@@ -228,20 +231,24 @@ def run(config):
   if unet_params:
     params["unet"] = unet_params

+  # maybe load lora and create interceptor
+  params, lora_interceptor = maybe_load_lora(config, pipeline, params)
+
   if config.lightning_repo:
     pipeline, params = load_sdxllightning_unet(config, pipeline, params)

-  # Don't restore the train state to save memory, just restore params
+  # Don't restore the full train state, instead, just restore params
   # and create an inference state.
-  unet_state, unet_state_shardings = max_utils.setup_initial_state(
-    model=pipeline.unet,
-    tx=None,
-    config=config,
-    mesh=checkpoint_loader.mesh,
-    weights_init_fn=weights_init_fn,
-    model_params=params.get("unet", None),
-    training=False,
-  )
+  with nn.intercept_methods(lora_interceptor):
+    unet_state, unet_state_shardings = max_utils.setup_initial_state(
+      model=pipeline.unet,
+      tx=None,
+      config=config,
+      mesh=checkpoint_loader.mesh,
+      weights_init_fn=weights_init_fn,
+      model_params=params.get("unet", None),
+      training=False,
+    )

   vae_state, vae_state_shardings = checkpoint_loader.create_vae_state(
     pipeline, params, checkpoint_item_name="vae_state", is_training=False
@@ -267,14 +274,6 @@ def run(config):
   states["text_encoder_state"] = text_encoder_state
   states["text_encoder_2_state"] = text_encoder_2_state

-  noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained(
-    config.pretrained_model_name_or_path,
-    revision=config.revision,
-    subfolder="scheduler",
-    dtype=jnp.float32,
-    timestep_spacing="trailing",
-  )
-
   pipeline.scheduler = noise_scheduler
   params["scheduler"] = noise_scheduler_state

@@ -293,10 +292,12 @@ def run(config):
   )

   s = time.time()
-  p_run_inference(states).block_until_ready()
+  with nn.intercept_methods(lora_interceptor):
+    p_run_inference(states).block_until_ready()
   print("compile time: ", (time.time() - s))
   s = time.time()
-  images = p_run_inference(states).block_until_ready()
+  with nn.intercept_methods(lora_interceptor):
+    images = p_run_inference(states).block_until_ready()
   print("inference time: ", (time.time() - s))
   images = jax.experimental.multihost_utils.process_allgather(images)
   numpy_images = np.array(images)
```
Lines changed: 15 additions & 0 deletions
```python
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .lora_pipeline import StableDiffusionLoraLoaderMixin
```
Lines changed: 106 additions & 0 deletions
```python
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..models.modeling_utils import load_state_dict
from ..utils import _get_model_file

import safetensors


class LoRABaseMixin:
  """Utility class for handling LoRAs"""

  _lora_lodable_modules = []
  num_fused_loras = 0

  def load_lora_weights(self, **kwargs):
    raise NotImplementedError("`load_lora_weights()` is not implemented.")

  @classmethod
  def _fetch_state_dict(
      cls,
      pretrained_model_name_or_path_or_dict,
      weight_name,
      use_safetensors,
      local_files_only,
      cache_dir,
      force_download,
      resume_download,
      proxies,
      use_auth_token,
      revision,
      subfolder,
      user_agent,
      allow_pickle,
  ):
    from .lora_pipeline import LORA_WEIGHT_NAME_SAFE

    model_file = None
    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
      # Let's first try to load .safetensors weights
      if (use_safetensors and weight_name is None) or (weight_name is not None and weight_name.endswith(".safetensors")):
        try:
          # Here we're relaxing the loading check to enable more Inference API
          # friendliness where sometimes, it's not at all possible to automatically
          # determine `weight_name`.
          if weight_name is None:
            weight_name = cls._best_guess_weight_name(
                pretrained_model_name_or_path_or_dict,
                file_extension=".safetensors",
                local_files_only=local_files_only,
            )
          model_file = _get_model_file(
              pretrained_model_name_or_path_or_dict,
              weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
              cache_dir=cache_dir,
              force_download=force_download,
              resume_download=resume_download,
              proxies=proxies,
              local_files_only=local_files_only,
              use_auth_token=use_auth_token,
              revision=revision,
              subfolder=subfolder,
              user_agent=user_agent,
          )
          state_dict = safetensors.torch.load_file(model_file, device="cpu")
        except (IOError, safetensors.SafetensorError) as e:
          if not allow_pickle:
            raise e
          # try loading non-safetensors weights
          model_file = None

      if model_file is None:
        if weight_name is None:
          weight_name = cls._best_guess_weight_name(
              pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
          )
        model_file = _get_model_file(
            pretrained_model_name_or_path_or_dict,
            weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
        )
        state_dict = load_state_dict(model_file)
    else:
      state_dict = pretrained_model_name_or_path_or_dict

    return state_dict
```
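The control flow of `_fetch_state_dict` above is a two-stage fallback: prefer `.safetensors` weights, and only fall back to pickle-based `.bin` weights when `allow_pickle` is set. A simplified, self-contained skeleton of that strategy, with stub loader callables standing in for `_get_model_file` plus the real deserializers:

```python
def fetch_state_dict(source, weight_name=None, allow_pickle=False,
                     load_safetensors=None, load_pickle=None):
  """Simplified sketch of the safetensors-first loading strategy.

  `load_safetensors` / `load_pickle` are stand-ins for resolving and
  deserializing the weight file; an in-memory dict passes through unchanged.
  """
  if isinstance(source, dict):
    return source
  state_dict = None
  # Stage 1: try the safe format when no name is given or it ends in .safetensors.
  if weight_name is None or weight_name.endswith(".safetensors"):
    try:
      state_dict = load_safetensors(source, weight_name or "lora.safetensors")
    except IOError:
      if not allow_pickle:
        raise
  # Stage 2: fall back to legacy pickle-based weights.
  if state_dict is None:
    state_dict = load_pickle(source, weight_name or "lora.bin")
  return state_dict


passthrough = fetch_state_dict({"w": 1})  # dicts are returned as-is
via_safe = fetch_state_dict("some/repo",
                            load_safetensors=lambda s, w: {"fmt": "safe"})

def missing(s, w):
  raise IOError("no safetensors file")

via_bin = fetch_state_dict("some/repo", allow_pickle=True,
                           load_safetensors=missing,
                           load_pickle=lambda s, w: {"fmt": "bin"})
```

With `allow_pickle=False` the `IOError` propagates instead of silently falling back, which mirrors the stricter default in the mixin above.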
