Skip to content

Commit 91d7f5c

Browse files
jfacevedo-google and ksikiric
authored and committed
Support other format loras. update readme. Run code_style.
1 parent 719e6db commit 91d7f5c

6 files changed

Lines changed: 51 additions & 25 deletions

File tree

README.md

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,8 @@
1717
[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
1818

1919
# What's new?
20-
- **`2025/02/08**: Flux schnell & dev inference.
20+
- **`2025/02/12`**: Flux LoRA for inference.
21+
- **`2025/02/08`**: Flux schnell & dev inference.
2122
- **`2024/12/12`**: Load multiple LoRAs for inference.
2223
- **`2024/10/22`**: LoRA support for Hyper SDXL.
2324
- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
@@ -47,7 +48,8 @@ MaxDiffusion supports
4748
* [Training](#training)
4849
* [Dreambooth](#dreambooth)
4950
* [Inference](#inference)
50-
* [Flux](#flux)
51+
* [Flux](#flux)
52+
* [Flux LoRA](#flux-lora)
5153
* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
5254
* [Load Multiple LoRA](#load-multiple-lora)
5355
* [SDXL Lightning](#sdxl-lightning)
@@ -169,6 +171,24 @@ To generate images, run the following command:
169171
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
170172
```
171173

174+
## Flux LoRA
175+
176+
Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know.
177+
178+
Tested with [Amateur Photography](https://civitai.com/models/652699/amateur-photography-flux-dev) and [XLabs-AI](https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main) LoRA collection.
179+
180+
First download the LoRA file to a local directory, for example, `/home/jfacevedo/anime_lora.safetensors`. Then run as follows:
181+
182+
```bash
183+
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}'
184+
```
185+
186+
Loading multiple LoRAs is supported as follows:
187+
188+
```bash
189+
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors", "/home/jfacevedo/amateurphoto-v6-forcu.safetensors"], "weight_name" : ["anime_lora.safetensors","amateurphoto-v6-forcu.safetensors"], "adapter_name" : ["anime","realistic"], "scale": [0.6, 0.6], "from_pt": ["true","true"]}'
190+
```
191+
172192
## Hyper SDXL LoRA
173193
174194
Supports Hyper-SDXL models from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD)

src/maxdiffusion/generate_flux.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -43,25 +43,25 @@
4343
)
4444
from maxdiffusion.loaders.flux_lora_pipeline import FluxLoraLoaderMixin
4545

46+
4647
def maybe_load_flux_lora(config, lora_loader, params):
4748
def _noop_interceptor(next_fn, args, kwargs, context):
4849
return next_fn(*args, **kwargs)
4950

5051
lora_config = config.lora_config
51-
interceptors= [_noop_interceptor]
52+
interceptors = [_noop_interceptor]
5253
if len(lora_config["lora_model_name_or_path"]) > 0:
5354
interceptors = []
5455
for i in range(len(lora_config["lora_model_name_or_path"])):
5556
params, rank, network_alphas = lora_loader.load_lora_weights(
56-
config,
57-
lora_config["lora_model_name_or_path"][i],
58-
weight_name=lora_config["weight_name"][i],
59-
params=params,
60-
adapter_name=lora_config["adapter_name"][i],
57+
config,
58+
lora_config["lora_model_name_or_path"][i],
59+
weight_name=lora_config["weight_name"][i],
60+
params=params,
61+
adapter_name=lora_config["adapter_name"][i],
6162
)
6263
interceptor = lora_loader.make_lora_interceptor(params, rank, network_alphas, lora_config["adapter_name"][i])
6364
interceptors.append(interceptor)
64-
6565
return params, interceptors
6666

6767

@@ -501,6 +501,8 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
501501
for i, image in enumerate(imgs):
502502
Image.fromarray(image).save(f"flux_{i}.png")
503503

504+
return imgs
505+
504506

505507
def main(argv: Sequence[str]) -> None:
506508
pyconfig.initialize(argv)

src/maxdiffusion/loaders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,4 +13,4 @@
1313
# limitations under the License.
1414

1515
from .lora_pipeline import StableDiffusionLoraLoaderMixin
16-
from .flux_lora_pipeline import FluxLoraLoaderMixin
16+
from .flux_lora_pipeline import FluxLoraLoaderMixin

src/maxdiffusion/loaders/flux_lora_pipeline.py

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -21,25 +21,27 @@
2121
from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax
2222
from huggingface_hub.utils import validate_hf_hub_args
2323
from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor)
24+
25+
2426
class FluxLoraLoaderMixin(LoRABaseMixin):
2527

2628
_lora_lodable_modules = ["transformer", "text_encoder"]
27-
29+
2830
def load_lora_weights(
2931
self,
3032
config,
3133
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]],
3234
params,
3335
adapter_name=None,
34-
**kwargs
36+
**kwargs,
3537
):
3638
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3739

3840
params, rank, network_alphas = self.load_lora(
39-
config,
40-
state_dict,
41-
params=params,
42-
adapter_name=adapter_name,
41+
config,
42+
state_dict,
43+
params=params,
44+
adapter_name=adapter_name,
4345
)
4446

4547
return params, rank, network_alphas
@@ -64,7 +66,7 @@ def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
6466
transformer_keys = flatten_dict(params["transformer"]).keys()
6567
lora_keys, transformer_alphas = cls.rename_for_interceptor(transformer_keys, network_alphas, adapter_name)
6668
network_alphas_for_interceptor.update(transformer_alphas)
67-
69+
6870
def _intercept(next_fn, args, kwargs, context):
6971
mod = context.module
7072
while mod is not None:
@@ -138,8 +140,8 @@ def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs):
138140
)
139141

140142
return state_dict
141-
143+
142144
@classmethod
143145
def load_lora(cls, config, state_dict, params, adapter_name=None):
144146
params, rank, network_alphas = convert_flux_lora_pytorch_state_dict_to_flax(config, state_dict, params, adapter_name)
145-
return params, rank, network_alphas
147+
return params, rank, network_alphas

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -222,6 +222,7 @@ def create_flax_params_from_pytorch_state(
222222
renamed_network_alphas[tuple(flax_key_list)] = network_alpha_value
223223
return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, renamed_network_alphas
224224

225+
225226
def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, adapter_name):
226227
pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()}
227228
transformer_params = flatten_dict(unfreeze(params["transformer"]))
@@ -243,7 +244,7 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params,
243244
renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.up", f"attn.i_qkv.lora-{adapter_name}.up")
244245
renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.down", f"attn.e_qkv.lora-{adapter_name}.down")
245246
renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.up", f"attn.e_qkv.lora-{adapter_name}.up")
246-
247+
247248
renamed_pt_key = renamed_pt_key.replace("_img_attn_proj", ".attn.i_proj")
248249
renamed_pt_key = renamed_pt_key.replace("_img_attn_qkv", ".attn.i_qkv")
249250
renamed_pt_key = renamed_pt_key.replace("_img_mlp_0", ".img_mlp.layers_0")
@@ -258,20 +259,20 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params,
258259
renamed_pt_key = renamed_pt_key.replace("_linear1", ".linear1")
259260
renamed_pt_key = renamed_pt_key.replace("_linear2", ".linear2")
260261
renamed_pt_key = renamed_pt_key.replace("_modulation_lin", ".norm.lin")
261-
262+
262263
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
263-
264+
264265
pt_tuple_key = tuple(renamed_pt_key.split("."))
265266
if "alpha" in pt_tuple_key:
266-
pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", 'down', 'kernel')
267+
pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "down", "kernel")
267268
network_alphas[tuple([*pt_tuple_key])] = tensor.item()
268-
pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", 'up', 'kernel')
269+
pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "up", "kernel")
269270
network_alphas[tuple([*pt_tuple_key])] = tensor.item()
270271
else:
271272
if pt_tuple_key[-2] == "up":
272273
rank = tensor.shape[1]
273274
transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype)
274-
275+
275276
params["transformer"] = unflatten_dict(transformer_params)
276277

277278
return params, rank, network_alphas

src/maxdiffusion/tests/generate_flux_smoke_test.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ def download_blob(gcs_file, local_file):
2222
gcs_dir_arr = gcs_file.replace("gs://", "").split("/")
2323
storage_client = storage.Client()
2424
bucket = storage_client.get_bucket(gcs_dir_arr[0])
25+
blob_loc = "/".join(gcs_dir_arr[1:])
2526
blob = bucket.blob("/".join(gcs_dir_arr[1:]))
2627
blob.download_to_filename(local_file)
2728

0 commit comments

Comments (0)