Skip to content

Commit 1f2e65c

Browse files
Support other format loras. update readme. Run code_style.
1 parent 4c68d53 commit 1f2e65c

9 files changed

Lines changed: 139 additions & 26 deletions

File tree

README.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
[![Unit Tests](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
1818

1919
# What's new?
20+
- **`2025/02/12`**: Flux LoRA for inference.
2021
- **`2025/02/08`**: Flux schnell & dev inference.
2122
- **`2024/12/12`**: Load multiple LoRAs for inference.
2223
- **`2024/10/22`**: LoRA support for Hyper SDXL.
@@ -47,7 +48,8 @@ MaxDiffusion supports
4748
* [Training](#training)
4849
* [Dreambooth](#dreambooth)
4950
* [Inference](#inference)
50-
* [Flux](#flux)
51+
* [Flux](#flux)
52+
* [Flux LoRA](#flux-lora)
5153
* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
5254
* [Load Multiple LoRA](#load-multiple-lora)
5355
* [SDXL Lightning](#sdxl-lightning)
@@ -169,6 +171,24 @@ To generate images, run the following command:
169171
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
170172
```
171173

174+
## Flux LoRA
175+
176+
Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know.
177+
178+
Tested with [Amateur Photography](https://civitai.com/models/652699/amateur-photography-flux-dev) and [XLabs-AI](https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main) LoRA collection.
179+
180+
First download the LoRA file to a local directory, for example, `/home/jfacevedo/anime_lora.safetensors`. Then run as follows:
181+
182+
```bash
183+
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}'
184+
```
185+
186+
Loading multiple LoRAs is supported as follows:
187+
188+
```bash
189+
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors", "/home/jfacevedo/amateurphoto-v6-forcu.safetensors"], "weight_name" : ["anime_lora.safetensors","amateurphoto-v6-forcu.safetensors"], "adapter_name" : ["anime","realistic"], "scale": [0.6, 0.6], "from_pt": ["true","true"]}'
190+
```
191+
172192
## Hyper SDXL LoRA
173193
174194
Supports Hyper-SDXL models from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD)

src/maxdiffusion/generate_flux.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,25 +43,25 @@
4343
)
4444
from maxdiffusion.loaders.flux_lora_pipeline import FluxLoraLoaderMixin
4545

46+
4647
def maybe_load_flux_lora(config, lora_loader, params):
4748
def _noop_interceptor(next_fn, args, kwargs, context):
4849
return next_fn(*args, **kwargs)
4950

5051
lora_config = config.lora_config
51-
interceptors= [_noop_interceptor]
52+
interceptors = [_noop_interceptor]
5253
if len(lora_config["lora_model_name_or_path"]) > 0:
5354
interceptors = []
5455
for i in range(len(lora_config["lora_model_name_or_path"])):
5556
params, rank, network_alphas = lora_loader.load_lora_weights(
56-
config,
57-
lora_config["lora_model_name_or_path"][i],
58-
weight_name=lora_config["weight_name"][i],
59-
params=params,
60-
adapter_name=lora_config["adapter_name"][i],
57+
config,
58+
lora_config["lora_model_name_or_path"][i],
59+
weight_name=lora_config["weight_name"][i],
60+
params=params,
61+
adapter_name=lora_config["adapter_name"][i],
6162
)
6263
interceptor = lora_loader.make_lora_interceptor(params, rank, network_alphas, lora_config["adapter_name"][i])
6364
interceptors.append(interceptor)
64-
6565
return params, interceptors
6666

6767

@@ -489,6 +489,8 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
489489
for i, image in enumerate(imgs):
490490
Image.fromarray(image).save(f"flux_{i}.png")
491491

492+
return imgs
493+
492494

493495
def main(argv: Sequence[str]) -> None:
494496
pyconfig.initialize(argv)

src/maxdiffusion/loaders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
# limitations under the License.
1414

1515
from .lora_pipeline import StableDiffusionLoraLoaderMixin
16-
from .flux_lora_pipeline import FluxLoraLoaderMixin
16+
from .flux_lora_pipeline import FluxLoraLoaderMixin

src/maxdiffusion/loaders/flux_lora_pipeline.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,27 @@
2121
from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax
2222
from huggingface_hub.utils import validate_hf_hub_args
2323
from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor)
24+
25+
2426
class FluxLoraLoaderMixin(LoRABaseMixin):
2527

2628
_lora_lodable_modules = ["transformer", "text_encoder"]
27-
29+
2830
def load_lora_weights(
2931
self,
3032
config,
3133
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]],
3234
params,
3335
adapter_name=None,
34-
**kwargs
36+
**kwargs,
3537
):
3638
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3739

3840
params, rank, network_alphas = self.load_lora(
39-
config,
40-
state_dict,
41-
params=params,
42-
adapter_name=adapter_name,
41+
config,
42+
state_dict,
43+
params=params,
44+
adapter_name=adapter_name,
4345
)
4446

4547
return params, rank, network_alphas
@@ -64,7 +66,7 @@ def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
6466
transformer_keys = flatten_dict(params["transformer"]).keys()
6567
lora_keys, transformer_alphas = cls.rename_for_interceptor(transformer_keys, network_alphas, adapter_name)
6668
network_alphas_for_interceptor.update(transformer_alphas)
67-
69+
6870
def _intercept(next_fn, args, kwargs, context):
6971
mod = context.module
7072
while mod is not None:
@@ -138,8 +140,8 @@ def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs):
138140
)
139141

140142
return state_dict
141-
143+
142144
@classmethod
143145
def load_lora(cls, config, state_dict, params, adapter_name=None):
144146
params, rank, network_alphas = convert_flux_lora_pytorch_state_dict_to_flax(config, state_dict, params, adapter_name)
145-
return params, rank, network_alphas
147+
return params, rank, network_alphas

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None):
144144
hidden_states = self.linear2(attn_mlp)
145145
hidden_states = gate * hidden_states
146146
hidden_states = residual + hidden_states
147-
if hidden_states.dtype == jnp.float16 or hidden_states.dtype == jnp.bfloat16:
147+
if hidden_states.dtype == jnp.float16:
148148
hidden_states = jnp.clip(hidden_states, -65504, 65504)
149149

150150
return hidden_states, temb, image_rotary_emb
@@ -294,7 +294,7 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=
294294

295295
context_ff_output = self.txt_mlp(norm_encoder_hidden_states)
296296
encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
297-
if encoder_hidden_states.dtype == jnp.float16 or encoder_hidden_states.dtype == jnp.bfloat16:
297+
if encoder_hidden_states.dtype == jnp.float16:
298298
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
299299
return hidden_states, encoder_hidden_states, temb, image_rotary_emb
300300

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def create_flax_params_from_pytorch_state(
222222
renamed_network_alphas[tuple(flax_key_list)] = network_alpha_value
223223
return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, renamed_network_alphas
224224

225+
225226
def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, adapter_name):
226227
pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()}
227228
transformer_params = flatten_dict(unfreeze(params["transformer"]))
@@ -243,7 +244,7 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params,
243244
renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.up", f"attn.i_qkv.lora-{adapter_name}.up")
244245
renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.down", f"attn.e_qkv.lora-{adapter_name}.down")
245246
renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.up", f"attn.e_qkv.lora-{adapter_name}.up")
246-
247+
247248
renamed_pt_key = renamed_pt_key.replace("_img_attn_proj", ".attn.i_proj")
248249
renamed_pt_key = renamed_pt_key.replace("_img_attn_qkv", ".attn.i_qkv")
249250
renamed_pt_key = renamed_pt_key.replace("_img_mlp_0", ".img_mlp.layers_0")
@@ -258,20 +259,20 @@ def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params,
258259
renamed_pt_key = renamed_pt_key.replace("_linear1", ".linear1")
259260
renamed_pt_key = renamed_pt_key.replace("_linear2", ".linear2")
260261
renamed_pt_key = renamed_pt_key.replace("_modulation_lin", ".norm.lin")
261-
262+
262263
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
263-
264+
264265
pt_tuple_key = tuple(renamed_pt_key.split("."))
265266
if "alpha" in pt_tuple_key:
266-
pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", 'down', 'kernel')
267+
pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "down", "kernel")
267268
network_alphas[tuple([*pt_tuple_key])] = tensor.item()
268-
pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", 'up', 'kernel')
269+
pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "up", "kernel")
269270
network_alphas[tuple([*pt_tuple_key])] = tensor.item()
270271
else:
271272
if pt_tuple_key[-2] == "up":
272273
rank = tensor.shape[1]
273274
transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype)
274-
275+
275276
params["transformer"] = unflatten_dict(transformer_params)
276277

277278
return params, rank, network_alphas
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import os
2+
import unittest
3+
import pytest
4+
5+
import numpy as np
6+
7+
from .. import pyconfig
8+
from absl.testing import absltest
9+
from maxdiffusion.generate_flux import run as generate_flux
10+
from PIL import Image
11+
from skimage.metrics import structural_similarity as ssim
12+
from google.cloud import storage
13+
14+
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
15+
16+
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
17+
18+
JAX_CACHE_DIR = "gs://maxdiffusion-github-runner-test-assets/cache_dir"
19+
20+
21+
def download_blob(gcs_file, local_file):
22+
gcs_dir_arr = gcs_file.replace("gs://", "").split("/")
23+
storage_client = storage.Client()
24+
bucket = storage_client.get_bucket(gcs_dir_arr[0])
25+
blob_loc = "/".join(gcs_dir_arr[1:])
26+
blob = bucket.blob("/".join(gcs_dir_arr[1:]))
27+
blob.download_to_filename(local_file)
28+
29+
30+
class GenerateFlux(unittest.TestCase):
31+
"""Smoke test."""
32+
33+
def setUp(self):
34+
GenerateFlux.dummy_data = {}
35+
36+
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
37+
def test_flux_dev(self):
38+
img_url = os.path.join(THIS_DIR, "images", "test_flux_dev.png")
39+
base_image = np.array(Image.open(img_url)).astype(np.uint8)
40+
pyconfig.initialize(
41+
[
42+
None,
43+
os.path.join(THIS_DIR, "..", "configs", "base_flux_dev.yml"),
44+
"run_name=flux_test",
45+
"output_dir=/tmp/",
46+
"jax_cache_dir=/tmp/cache_dir",
47+
'prompt="A cute corgi lives in a house made out of sushi, anime"',
48+
],
49+
unittest=True,
50+
)
51+
52+
images = generate_flux(pyconfig.config)
53+
test_image = np.array(images[0]).astype(np.uint8)
54+
ssim_compare = ssim(base_image, test_image, multichannel=True, channel_axis=-1, data_range=255)
55+
assert base_image.shape == test_image.shape
56+
assert ssim_compare >= 0.80
57+
58+
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
59+
def test_flux_dev_lora(self):
60+
img_url = os.path.join(THIS_DIR, "images", "test_flux_dev_lora.png")
61+
base_image = np.array(Image.open(img_url)).astype(np.uint8)
62+
63+
gcs_lora_path = "gs://maxdiffusion-github-runner-test-assets/flux/lora/anime_lora.safetensors"
64+
local_path = "/tmp/anime_lora.safetensors"
65+
download_blob(gcs_lora_path, local_path)
66+
67+
pyconfig.initialize(
68+
[
69+
None,
70+
os.path.join(THIS_DIR, "..", "configs", "base_flux_dev.yml"),
71+
"run_name=flux_test",
72+
"output_dir=/tmp/",
73+
"jax_cache_dir=/tmp/cache_dir",
74+
'prompt="A cute corgi lives in a house made out of sushi, anime"',
75+
'lora_config={"lora_model_name_or_path" : ["/tmp/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}',
76+
],
77+
unittest=True,
78+
)
79+
80+
images = generate_flux(pyconfig.config)
81+
test_image = np.array(images[1]).astype(np.uint8)
82+
ssim_compare = ssim(base_image, test_image, multichannel=True, channel_axis=-1, data_range=255)
83+
assert base_image.shape == test_image.shape
84+
assert ssim_compare >= 0.80
85+
86+
87+
if __name__ == "__main__":
88+
absltest.main()
753 KB
Loading
1.26 MB
Loading

0 commit comments

Comments
 (0)