Commit b4d0502

initial lora implementation for flux

Parent: fa1c23b

6 files changed: 254 additions & 35 deletions

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 10 additions & 20 deletions
@@ -54,27 +54,17 @@ precision: "DEFAULT"
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
-flash_block_sizes: {
-  "block_q" : 256,
-  "block_kv_compute" : 256,
-  "block_kv" : 256,
-  "block_q_dkv" : 256,
-  "block_kv_dkv" : 256,
-  "block_kv_dkv_compute" : 256,
-  "block_q_dq" : 256,
-  "block_kv_dq" : 256
-}
-
-# Use the following flash_block_sizes on v6e (Trillium).
+flash_block_sizes: {}
+# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
 # flash_block_sizes: {
-#   "block_q" : 2176,
-#   "block_kv_compute" : 2176,
-#   "block_kv" : 2176,
-#   "block_q_dkv" : 2176,
-#   "block_kv_dkv" : 2176,
-#   "block_kv_dkv_compute" : 2176,
-#   "block_q_dq" : 2176,
-#   "block_kv_dq" : 2176
+#   "block_q" : 1536,
+#   "block_kv_compute" : 1536,
+#   "block_kv" : 1536,
+#   "block_q_dkv" : 1536,
+#   "block_kv_dkv" : 1536,
+#   "block_kv_dkv_compute" : 1536,
+#   "block_q_dq" : 1536,
+#   "block_kv_dq" : 1536
 # }
 # GroupNorm groups
 norm_num_groups: 32

src/maxdiffusion/generate_flux.py

Lines changed: 52 additions & 14 deletions
@@ -16,6 +16,7 @@
 from typing import Callable, List, Union, Sequence
 from absl import app
+from contextlib import ExitStack
 import functools
 import math
 import time
@@ -24,6 +25,7 @@
 import jax
 from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
 import jax.numpy as jnp
+import flax.linen as nn
 from chex import Array
 from einops import rearrange
 from flax.linen import partitioning as nn_partitioning
@@ -39,6 +41,28 @@
     get_precision,
     setup_initial_state,
 )
+from maxdiffusion.loaders.flux_lora_pipeline import FluxLoraLoaderMixin
+
+def maybe_load_flux_lora(config, lora_loader, params):
+  def _noop_interceptor(next_fn, args, kwargs, context):
+    return next_fn(*args, **kwargs)
+
+  lora_config = config.lora_config
+  interceptors = [_noop_interceptor]
+  if len(lora_config["lora_model_name_or_path"]) > 0:
+    interceptors = []
+    for i in range(len(lora_config["lora_model_name_or_path"])):
+      params, rank, network_alphas = lora_loader.load_lora_weights(
+          config,
+          lora_config["lora_model_name_or_path"][i],
+          weight_name=lora_config["weight_name"][i],
+          params=params,
+          adapter_name=lora_config["adapter_name"][i],
+      )
+      interceptor = lora_loader.make_lora_interceptor(params, rank, network_alphas, lora_config["adapter_name"][i])
+      interceptors.append(interceptor)
+
+  return params, interceptors


 def unpack(x: Array, height: int, width: int) -> Array:
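
Note: the lora_config block that maybe_load_flux_lora reads is not defined anywhere in this diff. Judging only from the keys indexed above, it holds three parallel lists with one entry per adapter; a minimal sketch with hypothetical values:

# Hypothetical sketch of config.lora_config as consumed by maybe_load_flux_lora above;
# the real schema lives in the maxdiffusion config files, not in this diff.
lora_config = {
    "lora_model_name_or_path": ["some-org/some-flux-lora"],    # repo id or local path (hypothetical)
    "weight_name": ["pytorch_lora_weights.safetensors"],       # weight file per adapter (hypothetical)
    "adapter_name": ["my_adapter"],                             # becomes the "lora-my_adapter" param scope
}

When lora_model_name_or_path is empty, only the no-op interceptor is installed and inference runs unchanged.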
@@ -400,21 +424,29 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep

   # loads pretrained weights
   transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu")
+  params = {}
+  params["transformer"] = transformer_params
+  # maybe load lora and create interceptor
+  lora_loader = FluxLoraLoaderMixin()
+  params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params)
+  transformer_params = params["transformer"]
   # create transformer state
   weights_init_fn = functools.partial(
       transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False
   )
-  transformer_state, transformer_state_shardings = setup_initial_state(
-      model=transformer,
-      tx=None,
-      config=config,
-      mesh=mesh,
-      weights_init_fn=weights_init_fn,
-      model_params=None,
-      training=False,
-  )
-  transformer_state = transformer_state.replace(params=transformer_params)
-  transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
+  with ExitStack() as stack:
+    _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
+    transformer_state, transformer_state_shardings = setup_initial_state(
+        model=transformer,
+        tx=None,
+        config=config,
+        mesh=mesh,
+        weights_init_fn=weights_init_fn,
+        model_params=None,
+        training=False,
+    )
+    transformer_state = transformer_state.replace(params=transformer_params)
+    transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
   get_memory_allocations()

   states = {}
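
Note: the ExitStack plus nn.intercept_methods pattern above is what lets LoRA be injected without modifying the transformer definition: while the context managers are active, Flax routes every module method call through the installed interceptors. A self-contained toy example of the mechanism (standard Flax only, not code from this commit):

import jax
import jax.numpy as jnp
import flax.linen as nn

def double_dense_outputs(next_fn, args, kwargs, context):
  # Run the original method, then rewrite the result for Dense.__call__ only.
  out = next_fn(*args, **kwargs)
  if isinstance(context.module, nn.Dense) and context.method_name == "__call__":
    return out * 2.0
  return out

model = nn.Dense(features=4)
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x)
y_plain = model.apply(variables, x)
with nn.intercept_methods(double_dense_outputs):
  y_scaled = model.apply(variables, x)  # same params, output rewritten by the interceptor
assert jnp.allclose(y_scaled, 2.0 * y_plain)

In the commit, the same interceptor stack wraps both setup_initial_state and each p_run_inference call, so the LoRA layers are active when the state and shardings are built and when the jitted function is traced.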
@@ -444,17 +476,23 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
       out_shardings=None,
   )
   t0 = time.perf_counter()
-  p_run_inference(states).block_until_ready()
+  with ExitStack() as stack:
+    _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
+    p_run_inference(states).block_until_ready()
   t1 = time.perf_counter()
   max_logging.log(f"Compile time: {t1 - t0:.1f}s.")

   t0 = time.perf_counter()
-  imgs = p_run_inference(states).block_until_ready()
+  with ExitStack() as stack, jax.profiler.trace("/home/jfacevedo/trace/"):
+    _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
+    imgs = p_run_inference(states).block_until_ready()
   t1 = time.perf_counter()
   max_logging.log(f"Inference time: {t1 - t0:.1f}s.")

   t0 = time.perf_counter()
-  imgs = p_run_inference(states).block_until_ready()
+  with ExitStack() as stack:
+    _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
+    imgs = p_run_inference(states).block_until_ready()
   imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
   t1 = time.perf_counter()
   max_logging.log(f"Inference time: {t1 - t0:.1f}s.")

src/maxdiffusion/loaders/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@
 # limitations under the License.

 from .lora_pipeline import StableDiffusionLoraLoaderMixin
+from .flux_lora_pipeline import FluxLoraLoaderMixin
src/maxdiffusion/loaders/flux_lora_pipeline.py (new file)

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union, Dict
+from .lora_base import LoRABaseMixin
+from ..models.lora import LoRALinearLayer, BaseLoRALayer
+import jax.numpy as jnp
+from flax.traverse_util import flatten_dict, unflatten_dict
+from flax.core.frozen_dict import unfreeze
+from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax
+from huggingface_hub.utils import validate_hf_hub_args
+from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor)
+class FluxLoraLoaderMixin(LoRABaseMixin):
+
+  _lora_lodable_modules = ["transformer", "text_encoder"]
+
+  def load_lora_weights(
+      self,
+      config,
+      pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]],
+      params,
+      adapter_name=None,
+      **kwargs
+  ):
+    state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+    params, rank, network_alphas = self.load_lora(
+        config,
+        state_dict,
+        params=params,
+        adapter_name=adapter_name,
+    )
+
+    return params, rank, network_alphas
+
+  def rename_for_interceptor(params_keys, network_alphas, adapter_name):
+    new_params_keys = []
+    new_network_alphas = {}
+    lora_name = f"lora-{adapter_name}"
+    for layer_lora in params_keys:
+      if lora_name in layer_lora:
+        new_layer_lora = layer_lora[: layer_lora.index(lora_name)]
+        if new_layer_lora not in new_params_keys:
+          new_params_keys.append(new_layer_lora)
+          network_alpha = network_alphas[layer_lora]
+          new_network_alphas[new_layer_lora] = network_alpha
+    return new_params_keys, new_network_alphas
+
+  @classmethod
+  def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
+    network_alphas_for_interceptor = {}
+
+    transformer_keys = flatten_dict(params["transformer"]).keys()
+    lora_keys, transformer_alphas = cls.rename_for_interceptor(transformer_keys, network_alphas, adapter_name)
+    network_alphas_for_interceptor.update(transformer_alphas)
+
+    def _intercept(next_fn, args, kwargs, context):
+      mod = context.module
+      while mod is not None:
+        if isinstance(mod, BaseLoRALayer):
+          return next_fn(*args, **kwargs)
+        mod = mod.parent
+      h = next_fn(*args, **kwargs)
+      if context.method_name == "__call__":
+        module_path = context.module.path
+        if module_path in lora_keys:
+          lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor, adapter_name)
+          return lora_layer(h, *args, **kwargs)
+      return h
+
+    return _intercept
+
+  @classmethod
+  def _get_lora_layer(cls, module_path, module, rank, network_alphas, adapter_name):
+    network_alpha = network_alphas.get(module_path, None)
+    lora_module = LoRALinearLayer(
+        out_features=module.features,
+        rank=rank,
+        network_alpha=network_alpha,
+        dtype=module.dtype,
+        weights_dtype=module.param_dtype,
+        precision=module.precision,
+        name=f"lora-{adapter_name}",
+    )
+    return lora_module
+
+  @classmethod
+  @validate_hf_hub_args
+  def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs):
+
+    cache_dir = kwargs.pop("cache_dir", None)
+    force_download = kwargs.pop("force_download", False)
+    proxies = kwargs.pop("proxies", None)
+    local_files_only = kwargs.pop("local_files_only", None)
+    use_auth_token = kwargs.pop("use_auth_token", None)
+    revision = kwargs.pop("revision", None)
+    subfolder = kwargs.pop("subfolder", None)
+    weight_name = kwargs.pop("weight_name", None)
+    unet_config = kwargs.pop("unet_config", None)
+    use_safetensors = kwargs.pop("use_safetensors", None)
+    resume_download = kwargs.pop("resume_download", False)
+
+    allow_pickle = False
+    if use_safetensors is None:
+      use_safetensors = True
+      allow_pickle = True
+
+    user_agent = {
+        "file_type": "attn_procs_weights",
+        "framework": "pytorch",
+    }
+
+    state_dict = cls._fetch_state_dict(
+        pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path,
+        weight_name=weight_name,
+        use_safetensors=use_safetensors,
+        local_files_only=local_files_only,
+        cache_dir=cache_dir,
+        force_download=force_download,
+        resume_download=resume_download,
+        proxies=proxies,
+        use_auth_token=use_auth_token,
+        revision=revision,
+        subfolder=subfolder,
+        user_agent=user_agent,
+        allow_pickle=allow_pickle,
+    )
+
+    return state_dict
+
+  @classmethod
+  def load_lora(cls, config, state_dict, params, adapter_name=None):
+    params, rank, network_alphas = convert_flux_lora_pytorch_state_dict_to_flax(config, state_dict, params, adapter_name)
+    return params, rank, network_alphas
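
Note: LoRALinearLayer and BaseLoRALayer are imported from maxdiffusion.models.lora and are not shown in this diff. Based on how the interceptor calls lora_layer(h, *args, **kwargs) and on the standard LoRA formulation, a rough sketch of the behaviour such a layer would implement (an assumption, not the actual class):

from typing import Optional
import flax.linen as nn

class LoRALinearLayerSketch(nn.Module):
  """Hypothetical stand-in for maxdiffusion.models.lora.LoRALinearLayer."""
  out_features: int
  rank: int
  network_alpha: Optional[float] = None

  @nn.compact
  def __call__(self, h, x):
    # h: output of the frozen base layer; x: the original input to that layer.
    down = nn.Dense(self.rank, use_bias=False, name="down")
    up = nn.Dense(self.out_features, use_bias=False, name="up")
    scale = self.network_alpha / self.rank if self.network_alpha is not None else 1.0
    return h + scale * up(down(x))

Because _get_lora_layer names the module f"lora-{adapter_name}", the "down" and "up" kernels end up under the same "lora-<adapter>" scope that convert_flux_lora_pytorch_state_dict_to_flax writes the converted weights into.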

src/maxdiffusion/loaders/lora_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ def rename_for_interceptor(params_keys, network_alphas, adapter_name):

   @classmethod
   def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
-    # Only unet interceptor supported for now.
+
     network_alphas_for_interceptor = {}

     unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 45 additions & 0 deletions
@@ -222,6 +222,51 @@ def create_flax_params_from_pytorch_state(
       renamed_network_alphas[tuple(flax_key_list)] = network_alpha_value
   return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, renamed_network_alphas

+def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, adapter_name):
+  pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()}
+  transformer_params = flatten_dict(unfreeze(params["transformer"]))
+  network_alphas = {}
+  rank = None
+  for pt_key, tensor in pt_state_dict.items():
+    renamed_pt_key = rename_key(pt_key)
+    print("renamed_pt_key:", renamed_pt_key)
+    renamed_pt_key = renamed_pt_key.replace("lora_unet_", "")
+    renamed_pt_key = renamed_pt_key.replace("lora_down", f"lora-{adapter_name}.down")
+    renamed_pt_key = renamed_pt_key.replace("lora_up", f"lora-{adapter_name}.up")
+
+    if "double_blocks" in renamed_pt_key:
+      renamed_pt_key = renamed_pt_key.replace("_img_attn_proj", ".attn.i_proj")
+      renamed_pt_key = renamed_pt_key.replace("_img_attn_qkv", ".attn.i_qkv")
+      renamed_pt_key = renamed_pt_key.replace("_img_mlp_0", ".img_mlp.layers_0")
+      renamed_pt_key = renamed_pt_key.replace("_img_mlp_2", ".img_mlp.layers_2")
+      renamed_pt_key = renamed_pt_key.replace("_img_mod_lin", ".img_norm1.lin")
+      renamed_pt_key = renamed_pt_key.replace("_txt_attn_proj", ".attn.e_proj")
+      renamed_pt_key = renamed_pt_key.replace("_txt_attn_qkv", ".attn.e_qkv")
+      renamed_pt_key = renamed_pt_key.replace("_txt_mlp_0", ".txt_mlp.layers_0")
+      renamed_pt_key = renamed_pt_key.replace("_txt_mlp_2", ".txt_mlp.layers_2")
+      renamed_pt_key = renamed_pt_key.replace("_txt_mod_lin", ".txt_norm1.lin")
+    elif "single_blocks" in renamed_pt_key:
+      renamed_pt_key = renamed_pt_key.replace("_linear1", ".linear1")
+      renamed_pt_key = renamed_pt_key.replace("_linear2", ".linear2")
+      renamed_pt_key = renamed_pt_key.replace("_modulation_lin", ".norm.lin")
+
+    renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
+
+    pt_tuple_key = tuple(renamed_pt_key.split("."))
+    if "alpha" in pt_tuple_key:
+      pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", 'down', 'kernel')
+      network_alphas[tuple([*pt_tuple_key])] = tensor.item()
+      pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", 'up', 'kernel')
+      network_alphas[tuple([*pt_tuple_key])] = tensor.item()
+    else:
+      if pt_tuple_key[-2] == "up":
+        rank = tensor.shape[1]
+      transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype)
+
+  params["transformer"] = unflatten_dict(transformer_params)
+
+  return params, rank, network_alphas
+

 def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas, adapter_name):
   # Step 1: Convert pytorch tensor to numpy
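
Note: a short sketch (not part of this commit) of the shape conventions behind the conversion above. A PyTorch lora_up.weight has shape (out_features, rank), so tensor.shape[1] is the LoRA rank, and storing tensor.T yields Flax-style (in, out) kernels. With those conventions, the usual LoRA update combines the two kernels with the base weight like this:

import jax.numpy as jnp

def merged_lora_kernel(base_kernel, down_kernel, up_kernel, network_alpha=None):
  # base_kernel: (in_features, out_features), Flax convention
  # down_kernel: (in_features, rank); up_kernel: (rank, out_features) -- the transposed tensors stored above
  rank = up_kernel.shape[0]
  scale = network_alpha / rank if network_alpha is not None else 1.0
  return base_kernel + scale * (down_kernel @ up_kernel)

maxdiffusion keeps the down/up kernels separate and applies the update at call time through the interceptor rather than merging it into the base kernel, but the resulting computation is equivalent.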
