Skip to content

Commit efa11e4

Browse files
authored
Support multi lora loading (#137)
* adds multi lora support * update readme
1 parent b1c63d7 commit efa11e4

5 files changed

Lines changed: 68 additions & 40 deletions

File tree

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
[![Unit Tests](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
1818

1919
# What's new?
20+
- **`2024/12/12`**: Load multiple LoRAs for inference.
2021
- **`2024/10/22`**: LoRA support for Hyper SDXL.
2122
- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
2223
- **`2024/7/20`**: Dreambooth training for Stable Diffusion 1.x,2.x is now supported.
@@ -33,6 +34,7 @@ MaxDiffusion supports
3334
* Stable Diffusion XL (training and inference).
3435
* Stable Diffusion Lightning (inference).
3536
* Hyper-SD XL LoRA loading (inference).
37+
* Multiple LoRA loading (SDXL inference).
3638
* ControlNet inference (Stable Diffusion 1.4 & SDXL).
3739
* Dreambooth training support for Stable Diffusion 1.x,2.x.
3840

@@ -45,6 +47,7 @@ MaxDiffusion supports
4547
* [Dreambooth](#dreambooth)
4648
* [Inference](#inference)
4749
* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
50+
* [Load Multiple LoRA](#load-multiple-lora)
4851
* [SDXL Lightning](#sdxl-lightning)
4952
* [ControlNet](#controlnet)
5053
* [Comparison To Alternatives](#comparison-to-alternatives)
@@ -139,6 +142,14 @@ To generate images, run the following command:
139142
python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=2 do_classifier_free_guidance=False prompt="a photograph of a cat wearing a hat riding a skateboard in a park." per_device_batch_size=1 pretrained_model_name_or_path="Lykon/AAM_XL_AnimeMix" from_pt=True revision=main diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}'
140143
```
141144

145+
## Load Multiple LoRA
146+
147+
Supports loading multiple LoRAs for inference, either from local files or from the HuggingFace Hub.
148+
149+
```bash
150+
python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=30 do_classifier_free_guidance=True prompt="ultra detailed diagram blueprint of a papercut Sitting MaineCoon cat, wide canvas, ampereart, electrical diagram, bl3uprint, papercut" per_device_batch_size=1 diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors","TheLastBen/Papercut_SDXL"], "weight_name" : ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors","papercut.safetensors"], "adapter_name" : ["blueprint","papercut"], "scale": [0.8, 0.7], "from_pt": ["true", "true"]}'
151+
```
152+
142153
## SDXL Lightning
143154

144155
Single and Multi host inference is supported with sharding annotations:

src/maxdiffusion/generate_sdxl.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import functools
1818
from absl import app
19+
from contextlib import ExitStack
1920
from typing import Sequence
2021
import time
2122

@@ -233,14 +234,15 @@ def run(config):
233234
params["unet"] = unet_params
234235

235236
# maybe load lora and create interceptor
236-
params, lora_interceptor = maybe_load_lora(config, pipeline, params)
237+
params, lora_interceptors = maybe_load_lora(config, pipeline, params)
237238

238239
if config.lightning_repo:
239240
pipeline, params = load_sdxllightning_unet(config, pipeline, params)
240241

241242
# Don't restore the full train state, instead, just restore params
242243
# and create an inference state.
243-
with nn.intercept_methods(lora_interceptor):
244+
with ExitStack() as stack:
245+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
244246
unet_state, unet_state_shardings = max_utils.setup_initial_state(
245247
model=pipeline.unet,
246248
tx=None,
@@ -254,7 +256,8 @@ def run(config):
254256
vae_state, vae_state_shardings = checkpoint_loader.create_vae_state(
255257
pipeline, params, checkpoint_item_name="vae_state", is_training=False
256258
)
257-
with nn.intercept_methods(lora_interceptor):
259+
with ExitStack() as stack:
260+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
258261
text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
259262
pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
260263
)
@@ -293,11 +296,13 @@ def run(config):
293296
)
294297

295298
s = time.time()
296-
with nn.intercept_methods(lora_interceptor):
299+
with ExitStack() as stack:
300+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
297301
p_run_inference(states).block_until_ready()
298302
print("compile time: ", (time.time() - s))
299303
s = time.time()
300-
with nn.intercept_methods(lora_interceptor):
304+
with ExitStack() as stack:
305+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
301306
images = p_run_inference(states).block_until_ready()
302307
print("inference time: ", (time.time() - s))
303308
images = jax.experimental.multihost_utils.process_allgather(images)

src/maxdiffusion/loaders/lora_pipeline.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def load_lora_weights(
8888
return params, rank, network_alphas
8989

9090
@classmethod
91-
def _get_lora_layer(cls, module_path, module, rank, network_alphas):
91+
def _get_lora_layer(cls, module_path, module, rank, network_alphas, adapter_name):
9292
is_conv = any("conv" in str_ for str_ in module_path)
9393
network_alpha = network_alphas.get(module_path, None)
9494
if is_conv:
@@ -105,7 +105,7 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):
105105
dtype=module.dtype,
106106
weights_dtype=module.param_dtype,
107107
precision=module.precision,
108-
name="lora",
108+
name=f"lora-{adapter_name}",
109109
)
110110
else:
111111
lora_module = LoRALinearLayer(
@@ -115,39 +115,41 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):
115115
dtype=module.dtype,
116116
weights_dtype=module.param_dtype,
117117
precision=module.precision,
118-
name="lora",
118+
name=f"lora-{adapter_name}",
119119
)
120120
return lora_module
121121

122-
def rename_for_interceptor(params_keys, network_alphas):
122+
def rename_for_interceptor(params_keys, network_alphas, adapter_name):
123123
new_params_keys = []
124124
new_network_alphas = {}
125+
lora_name = f"lora-{adapter_name}"
125126
for layer_lora in params_keys:
126-
if "lora" in layer_lora:
127-
new_layer_lora = layer_lora[: layer_lora.index("lora")]
127+
if lora_name in layer_lora:
128+
new_layer_lora = layer_lora[: layer_lora.index(lora_name)]
128129
if new_layer_lora not in new_params_keys:
129130
new_params_keys.append(new_layer_lora)
130131
network_alpha = network_alphas[layer_lora]
131132
new_network_alphas[new_layer_lora] = network_alpha
132133
return new_params_keys, new_network_alphas
133134

134135
@classmethod
135-
def make_lora_interceptor(cls, params, rank, network_alphas):
136+
def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
136137
# Only unet interceptor supported for now.
137138
network_alphas_for_interceptor = {}
138139

139140
unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()
140-
lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
141+
lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas, adapter_name)
141142
network_alphas_for_interceptor.update(unet_alphas)
142143

143144
text_encoder_keys = flax.traverse_util.flatten_dict(params["text_encoder"]).keys()
144-
text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas)
145+
text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas, adapter_name)
145146
lora_keys.extend(text_encoder_keys)
146147
network_alphas_for_interceptor.update(text_encoder_alphas)
147-
148148
if "text_encoder_2" in params.keys():
149149
text_encoder_2_keys = flax.traverse_util.flatten_dict(params["text_encoder_2"]).keys()
150-
text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas)
150+
text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(
151+
text_encoder_2_keys, network_alphas, adapter_name
152+
)
151153
lora_keys.extend(text_encoder_2_keys)
152154
network_alphas_for_interceptor.update(text_encoder_2_alphas)
153155

@@ -161,7 +163,7 @@ def _intercept(next_fn, args, kwargs, context):
161163
if context.method_name == "__call__":
162164
module_path = context.module.path
163165
if module_path in lora_keys:
164-
lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor)
166+
lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor, adapter_name)
165167
return lora_layer(h, *args, **kwargs)
166168
return h
167169

@@ -290,5 +292,5 @@ def load_lora(cls, state_dict, network_alphas, params, adapter_name=None, _pipel
290292
`default_{i}` where i is the total number of adapters being loaded.
291293
"""
292294
# Load the layers corresponding to Unet.
293-
params, rank, network_alphas = convert_lora_pytorch_state_dict_to_flax(state_dict, params, network_alphas)
295+
params, rank, network_alphas = convert_lora_pytorch_state_dict_to_flax(state_dict, params, network_alphas, adapter_name)
294296
return params, rank, network_alphas

src/maxdiffusion/maxdiffusion_utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,24 @@ def _noop_interceptor(next_fn, args, kwargs, context):
4545
return next_fn(*args, **kwargs)
4646

4747
lora_config = config.lora_config
48-
interceptor = _noop_interceptor
48+
interceptors = [_noop_interceptor]
4949
if len(lora_config["lora_model_name_or_path"]) > 0:
5050
# Load every configured LoRA and build one interceptor per adapter;
5151
# the interceptors are entered together (via ExitStack) at the call sites.
5252
# TODO - merge LoRAs into a single weight set instead of stacking interceptors.
53-
params, rank, network_alphas = pipeline.load_lora_weights(
54-
lora_config["lora_model_name_or_path"][0],
55-
weight_name=lora_config["weight_name"][0],
56-
params=params,
57-
adapter_name=lora_config["adapter_name"][0],
58-
unet_config=pipeline.unet.config,
59-
)
60-
interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas)
53+
interceptors = []
54+
for i in range(len(lora_config["lora_model_name_or_path"])):
55+
params, rank, network_alphas = pipeline.load_lora_weights(
56+
lora_config["lora_model_name_or_path"][i],
57+
weight_name=lora_config["weight_name"][i],
58+
params=params,
59+
adapter_name=lora_config["adapter_name"][i],
60+
unet_config=pipeline.unet.config,
61+
)
62+
interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas, lora_config["adapter_name"][i])
63+
interceptors.append(interceptor)
6164

62-
return params, interceptor
65+
return params, interceptors
6366

6467

6568
def vae_apply(images, sample_rng, vae, vae_params):

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,13 @@ def get_network_alpha_value(pt_key, network_alphas):
130130

131131

132132
def create_flax_params_from_pytorch_state(
133-
pt_state_dict, unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, network_alphas, is_lora=False
133+
pt_state_dict,
134+
unet_state_dict,
135+
text_encoder_state_dict,
136+
text_encoder_2_state_dict,
137+
network_alphas,
138+
adapter_name,
139+
is_lora=False,
134140
):
135141
rank = None
136142
renamed_network_alphas = {}
@@ -157,19 +163,21 @@ def create_flax_params_from_pytorch_state(
157163
flax_key_list = [*pt_tuple_key]
158164
if "text_encoder" in pt_tuple_key or "text_encoder_2" in pt_tuple_key:
159165
rename_from_to = (
160-
("to_k_lora", ("k_proj", "lora")),
161-
("to_q_lora", ("q_proj", "lora")),
162-
("to_v_lora", ("v_proj", "lora")),
163-
("to_out_lora", ("out_proj", "lora")),
166+
("to_k_lora", ("k_proj", f"lora-{adapter_name}")),
167+
("to_q_lora", ("q_proj", f"lora-{adapter_name}")),
168+
("to_v_lora", ("v_proj", f"lora-{adapter_name}")),
169+
("to_out_lora", ("out_proj", f"lora-{adapter_name}")),
170+
("lora", f"lora-{adapter_name}"),
164171
("weight", "kernel"),
165172
)
166173
# the unet
167174
else:
168175
rename_from_to = (
169-
("to_k_lora", ("to_k", "lora")),
170-
("to_q_lora", ("to_q", "lora")),
171-
("to_v_lora", ("to_v", "lora")),
172-
("to_out_lora", ("to_out_0", "lora")),
176+
("to_k_lora", ("to_k", f"lora-{adapter_name}")),
177+
("to_q_lora", ("to_q", f"lora-{adapter_name}")),
178+
("to_v_lora", ("to_v", f"lora-{adapter_name}")),
179+
("to_out_lora", ("to_out_0", f"lora-{adapter_name}")),
180+
("lora", f"lora-{adapter_name}"),
173181
("weight", "kernel"),
174182
)
175183
for rename_from, rename_to in rename_from_to:
@@ -206,11 +214,10 @@ def create_flax_params_from_pytorch_state(
206214

207215
if network_alpha_value >= 0:
208216
renamed_network_alphas[tuple(flax_key_list)] = network_alpha_value
209-
210217
return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, renamed_network_alphas
211218

212219

213-
def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas):
220+
def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas, adapter_name):
214221
# Step 1: Convert pytorch tensor to numpy
215222
# sometimes we load weights in bf16 and numpy doesn't support it
216223
pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()}
@@ -223,7 +230,7 @@ def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alpha
223230
text_encoder_2_params = None
224231
(unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, network_alphas) = (
225232
create_flax_params_from_pytorch_state(
226-
pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, network_alphas, is_lora=True
233+
pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, network_alphas, adapter_name, is_lora=True
227234
)
228235
)
229236
params["unet"] = unflatten_dict(unet_state_dict)

0 commit comments

Comments
 (0)