Skip to content

Commit 41e901c

Browse files
authored
Flux lora (#148)
Adds support for loading LoRA weights for Flux inference.
1 parent 7f0f5bc commit 41e901c

10 files changed

Lines changed: 365 additions & 20 deletions

File tree

README.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
[![Unit Tests](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
1818

1919
# What's new?
20+
- **`2025/02/12`**: Flux LoRA for inference.
2021
- **`2025/02/08`**: Flux schnell & dev inference.
2122
- **`2024/12/12`**: Load multiple LoRAs for inference.
2223
- **`2024/10/22`**: LoRA support for Hyper SDXL.
@@ -47,7 +48,8 @@ MaxDiffusion supports
4748
* [Training](#training)
4849
* [Dreambooth](#dreambooth)
4950
* [Inference](#inference)
50-
* [Flux](#flux)
51+
* [Flux](#flux)
52+
* [Flux LoRA](#flux-lora)
5153
* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
5254
* [Load Multiple LoRA](#load-multiple-lora)
5355
* [SDXL Lightning](#sdxl-lightning)
@@ -169,6 +171,24 @@ To generate images, run the following command:
169171
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
170172
```
171173

174+
## Flux LoRA
175+
176+
Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know.
177+
178+
Tested with [Amateur Photography](https://civitai.com/models/652699/amateur-photography-flux-dev) and [XLabs-AI](https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main) LoRA collection.
179+
180+
First download the LoRA file to a local directory, for example, `/home/jfacevedo/anime_lora.safetensors`. Then run as follows:
181+
182+
```bash
183+
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}'
184+
```
185+
186+
Loading multiple LoRAs is supported as follows:
187+
188+
```bash
189+
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors", "/home/jfacevedo/amateurphoto-v6-forcu.safetensors"], "weight_name" : ["anime_lora.safetensors","amateurphoto-v6-forcu.safetensors"], "adapter_name" : ["anime","realistic"], "scale": [0.6, 0.6], "from_pt": ["true","true"]}'
190+
```
191+
172192
## Hyper SDXL LoRA
173193
174194
Supports Hyper-SDXL models from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD)

src/maxdiffusion/generate_flux.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Callable, List, Union, Sequence
1818
from absl import app
19+
from contextlib import ExitStack
1920
import functools
2021
import math
2122
import time
@@ -24,6 +25,7 @@
2425
import jax
2526
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
2627
import jax.numpy as jnp
28+
import flax.linen as nn
2729
from chex import Array
2830
from einops import rearrange
2931
from flax.linen import partitioning as nn_partitioning
@@ -39,6 +41,28 @@
3941
get_precision,
4042
setup_initial_state,
4143
)
44+
from maxdiffusion.loaders.flux_lora_pipeline import FluxLoraLoaderMixin
45+
46+
47+
def maybe_load_flux_lora(config, lora_loader, params):
  """Load any LoRA adapters requested in ``config.lora_config``.

  Args:
    config: pyconfig-style object exposing a ``lora_config`` dict with the
      parallel lists ``lora_model_name_or_path``, ``weight_name`` and
      ``adapter_name`` (one entry per adapter).
    lora_loader: object exposing ``load_lora_weights`` and
      ``make_lora_interceptor`` (e.g. ``FluxLoraLoaderMixin``).
    params: params pytree containing the "transformer" subtree; LoRA weights
      are merged into it.

  Returns:
    Tuple ``(params, interceptors)`` where ``interceptors`` is a list of flax
    method interceptors, one per configured adapter. When no LoRA is
    configured the list is empty — wrapping zero interceptors with
    ``nn.intercept_methods`` is a no-op, so no placeholder interceptor is
    needed (the previous no-op interceptor added overhead to every flax
    method call).
  """
  lora_config = config.lora_config
  interceptors = []
  for path, weight_name, adapter_name in zip(
      lora_config["lora_model_name_or_path"],
      lora_config["weight_name"],
      lora_config["adapter_name"],
  ):
    # Each adapter merges its weights into params and contributes one
    # interceptor; interceptors are later stacked via ExitStack.
    params, rank, network_alphas = lora_loader.load_lora_weights(
        config,
        path,
        weight_name=weight_name,
        params=params,
        adapter_name=adapter_name,
    )
    interceptors.append(lora_loader.make_lora_interceptor(params, rank, network_alphas, adapter_name))
  return params, interceptors
4266

4367

4468
def unpack(x: Array, height: int, width: int) -> Array:
@@ -97,7 +121,6 @@ def prepare_latent_image_ids(height, width):
97121
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
98122

99123
latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels)
100-
101124
return latent_image_ids.astype(jnp.bfloat16)
102125

103126

@@ -127,7 +150,6 @@ def run_inference(
127150
vec=vec,
128151
guidance_vec=guidance_vec,
129152
)
130-
131153
vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config)
132154

133155
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
@@ -373,21 +395,29 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
373395

374396
# loads pretrained weights
375397
transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu")
398+
params = {}
399+
params["transformer"] = transformer_params
400+
# maybe load lora and create interceptor
401+
lora_loader = FluxLoraLoaderMixin()
402+
params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params)
403+
transformer_params = params["transformer"]
376404
# create transformer state
377405
weights_init_fn = functools.partial(
378406
transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False
379407
)
380-
transformer_state, transformer_state_shardings = setup_initial_state(
381-
model=transformer,
382-
tx=None,
383-
config=config,
384-
mesh=mesh,
385-
weights_init_fn=weights_init_fn,
386-
model_params=None,
387-
training=False,
388-
)
389-
transformer_state = transformer_state.replace(params=transformer_params)
390-
transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
408+
with ExitStack() as stack:
409+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
410+
transformer_state, transformer_state_shardings = setup_initial_state(
411+
model=transformer,
412+
tx=None,
413+
config=config,
414+
mesh=mesh,
415+
weights_init_fn=weights_init_fn,
416+
model_params=None,
417+
training=False,
418+
)
419+
transformer_state = transformer_state.replace(params=transformer_params)
420+
transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
391421
get_memory_allocations()
392422

393423
states = {}
@@ -432,17 +462,23 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
432462
out_shardings=None,
433463
)
434464
t0 = time.perf_counter()
435-
p_run_inference(states).block_until_ready()
465+
with ExitStack() as stack:
466+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
467+
p_run_inference(states).block_until_ready()
436468
t1 = time.perf_counter()
437469
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")
438470

439471
t0 = time.perf_counter()
440-
imgs = p_run_inference(states).block_until_ready()
472+
with ExitStack() as stack, jax.profiler.trace("/home/jfacevedo/trace/"):
473+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
474+
imgs = p_run_inference(states).block_until_ready()
441475
t1 = time.perf_counter()
442476
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
443477

444478
t0 = time.perf_counter()
445-
imgs = p_run_inference(states).block_until_ready()
479+
with ExitStack() as stack:
480+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
481+
imgs = p_run_inference(states).block_until_ready()
446482
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
447483
t1 = time.perf_counter()
448484
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
@@ -453,6 +489,8 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
453489
for i, image in enumerate(imgs):
454490
Image.fromarray(image).save(f"flux_{i}.png")
455491

492+
return imgs
493+
456494

457495
def main(argv: Sequence[str]) -> None:
458496
pyconfig.initialize(argv)

src/maxdiffusion/loaders/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@
1313
# limitations under the License.
1414

1515
from .lora_pipeline import StableDiffusionLoraLoaderMixin
16+
from .flux_lora_pipeline import FluxLoraLoaderMixin
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Union, Dict
16+
from .lora_base import LoRABaseMixin
17+
from ..models.lora import LoRALinearLayer, BaseLoRALayer
18+
import jax.numpy as jnp
19+
from flax.traverse_util import flatten_dict
20+
from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax
21+
from huggingface_hub.utils import validate_hf_hub_args
22+
23+
24+
class FluxLoraLoaderMixin(LoRABaseMixin):
  """Loads LoRA weights for Flux and applies them via flax method interceptors.

  A LoRA checkpoint (PyTorch-layout safetensors) is converted to flax, merged
  into the params pytree under ``lora-<adapter_name>`` sub-keys, and applied
  at call time through a ``nn.intercept_methods`` interceptor instead of
  fusing the weights into the base parameters.
  """

  # NOTE(review): "lodable" looks like a typo for "loadable"; the attribute
  # name is kept unchanged because external code may reference it.
  _lora_lodable_modules = ["transformer", "text_encoder"]

  def load_lora_weights(
      self,
      config,
      pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]],
      params,
      adapter_name=None,
      **kwargs,
  ):
    """Fetch a LoRA state dict and merge it into ``params``.

    Args:
      config: pyconfig object forwarded to the weight converter.
      pretrained_model_name_or_path_or_dict: local file path, hub repo id, or
        an already-loaded state dict.
      params: params pytree to merge the converted LoRA weights into.
      adapter_name: name used to namespace this adapter's params
        (``lora-<adapter_name>``).
      **kwargs: forwarded to :meth:`lora_state_dict` (``weight_name``,
        ``cache_dir``, ``revision``, ...).

    Returns:
      ``(params, rank, network_alphas)``.
    """
    state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

    params, rank, network_alphas = self.load_lora(
        config,
        state_dict,
        params=params,
        adapter_name=adapter_name,
    )

    return params, rank, network_alphas

  @staticmethod
  def rename_for_interceptor(params_keys, network_alphas, adapter_name):
    """Map flattened LoRA param key paths back to their base-layer paths.

    Fix: declared ``@staticmethod``. The original plain function only worked
    because it was accessed as a class attribute (``cls.rename_for_interceptor``);
    an instance access would have bound ``self`` to ``params_keys``.

    Args:
      params_keys: flattened-dict key tuples from the params pytree.
      network_alphas: mapping from full LoRA key path to its network alpha.
      adapter_name: adapter whose ``lora-<adapter_name>`` segment is stripped.

    Returns:
      ``(new_params_keys, new_network_alphas)`` — de-duplicated base-layer key
      paths and their alphas (``None`` when no alpha was provided).
    """
    new_params_keys = []
    new_network_alphas = {}
    lora_name = f"lora-{adapter_name}"
    for layer_lora in params_keys:
      if lora_name in layer_lora:
        # Truncate the key path at the adapter segment to recover the wrapped
        # layer's own path, e.g. (..., "proj", "lora-anime", ...) -> (..., "proj").
        new_layer_lora = layer_lora[: layer_lora.index(lora_name)]
        if new_layer_lora not in new_params_keys:
          new_params_keys.append(new_layer_lora)
          network_alpha = network_alphas.get(layer_lora, None)
          new_network_alphas[new_layer_lora] = network_alpha
    return new_params_keys, new_network_alphas

  @classmethod
  def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
    """Build a flax interceptor that adds this adapter's LoRA delta to matched layers."""
    network_alphas_for_interceptor = {}

    transformer_keys = flatten_dict(params["transformer"]).keys()
    lora_keys, transformer_alphas = cls.rename_for_interceptor(transformer_keys, network_alphas, adapter_name)
    network_alphas_for_interceptor.update(transformer_alphas)

    def _intercept(next_fn, args, kwargs, context):
      # Never intercept calls originating inside a LoRA layer itself;
      # otherwise the LoRA's own dense calls would recurse into another add.
      mod = context.module
      while mod is not None:
        if isinstance(mod, BaseLoRALayer):
          return next_fn(*args, **kwargs)
        mod = mod.parent
      h = next_fn(*args, **kwargs)
      if context.method_name == "__call__":
        module_path = context.module.path
        if module_path in lora_keys:
          lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor, adapter_name)
          # h + scaled LoRA output; the LoRA layer re-consumes the original inputs.
          return lora_layer(h, *args, **kwargs)
      return h

    return _intercept

  @classmethod
  def _get_lora_layer(cls, module_path, module, rank, network_alphas, adapter_name):
    """Construct the LoRA layer wrapping ``module``, mirroring its dtype/precision."""
    network_alpha = network_alphas.get(module_path, None)
    lora_module = LoRALinearLayer(
        out_features=module.features,
        rank=rank,
        network_alpha=network_alpha,
        dtype=module.dtype,
        weights_dtype=module.param_dtype,
        precision=module.precision,
        # Name must match the params sub-key created at load time.
        name=f"lora-{adapter_name}",
    )
    return lora_module

  @classmethod
  @validate_hf_hub_args
  def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs):
    """Download/read a LoRA checkpoint and return its raw (PyTorch-layout) state dict.

    Accepts the usual hub kwargs (``cache_dir``, ``revision``, ``subfolder``,
    ``weight_name``, ``use_safetensors``, ...). When ``use_safetensors`` is not
    specified, safetensors is preferred with pickle as a fallback.
    """
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)
    resume_download = kwargs.pop("resume_download", False)

    allow_pickle = False
    if use_safetensors is None:
      use_safetensors = True
      allow_pickle = True

    user_agent = {
        "file_type": "attn_procs_weights",
        "framework": "pytorch",
    }

    state_dict = cls._fetch_state_dict(
        pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path,
        weight_name=weight_name,
        use_safetensors=use_safetensors,
        local_files_only=local_files_only,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        subfolder=subfolder,
        user_agent=user_agent,
        allow_pickle=allow_pickle,
    )

    return state_dict

  @classmethod
  def load_lora(cls, config, state_dict, params, adapter_name=None):
    """Convert a PyTorch LoRA state dict to flax and merge it into ``params``."""
    params, rank, network_alphas = convert_flux_lora_pytorch_state_dict_to_flax(config, state_dict, params, adapter_name)
    return params, rank, network_alphas

src/maxdiffusion/loaders/lora_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def rename_for_interceptor(params_keys, network_alphas, adapter_name):
134134

135135
@classmethod
136136
def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
137-
# Only unet interceptor supported for now.
137+
138138
network_alphas_for_interceptor = {}
139139

140140
unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None):
144144
hidden_states = self.linear2(attn_mlp)
145145
hidden_states = gate * hidden_states
146146
hidden_states = residual + hidden_states
147-
if hidden_states.dtype == jnp.float16 or hidden_states.dtype == jnp.bfloat16:
147+
if hidden_states.dtype == jnp.float16:
148148
hidden_states = jnp.clip(hidden_states, -65504, 65504)
149149

150150
return hidden_states, temb, image_rotary_emb
@@ -294,7 +294,7 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=
294294

295295
context_ff_output = self.txt_mlp(norm_encoder_hidden_states)
296296
encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
297-
if encoder_hidden_states.dtype == jnp.float16 or encoder_hidden_states.dtype == jnp.bfloat16:
297+
if encoder_hidden_states.dtype == jnp.float16:
298298
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
299299
return hidden_states, encoder_hidden_states, temb, image_rotary_emb
300300

0 commit comments

Comments
 (0)