Skip to content

Commit ff16ba6

Browse files
jfacevedo-google and ksikiric
authored and committed
initial lora implementation for flux
1 parent 1f28cb5 commit ff16ba6

5 files changed

Lines changed: 122 additions & 48 deletions

File tree

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,27 +54,17 @@ precision: "DEFAULT"
5454
from_pt: True
5555
split_head_dim: True
5656
attention: 'flash' # Supported attention: dot_product, flash
57-
flash_block_sizes: {
58-
"block_q" : 256,
59-
"block_kv_compute" : 256,
60-
"block_kv" : 256,
61-
"block_q_dkv" : 256,
62-
"block_kv_dkv" : 256,
63-
"block_kv_dkv_compute" : 256,
64-
"block_q_dq" : 256,
65-
"block_kv_dq" : 256
66-
}
67-
68-
# Use the following flash_block_sizes on v6e (Trillium).
57+
flash_block_sizes: {}
58+
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
6959
# flash_block_sizes: {
70-
# "block_q" : 2176,
71-
# "block_kv_compute" : 2176,
72-
# "block_kv" : 2176,
73-
# "block_q_dkv" : 2176,
74-
# "block_kv_dkv" : 2176,
75-
# "block_kv_dkv_compute" : 2176,
76-
# "block_q_dq" : 2176,
77-
# "block_kv_dq" : 2176
60+
# "block_q" : 1536,
61+
# "block_kv_compute" : 1536,
62+
# "block_kv" : 1536,
63+
# "block_q_dkv" : 1536,
64+
# "block_kv_dkv" : 1536,
65+
# "block_kv_dkv_compute" : 1536,
66+
# "block_q_dq" : 1536,
67+
# "block_kv_dq" : 1536
7868
# }
7969
# GroupNorm groups
8070
norm_num_groups: 32

src/maxdiffusion/generate_flux.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Callable, List, Union, Sequence
1818
from absl import app
19+
from contextlib import ExitStack
1920
import functools
2021
import math
2122
import time
@@ -24,6 +25,7 @@
2425
import jax
2526
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
2627
import jax.numpy as jnp
28+
import flax.linen as nn
2729
from chex import Array
2830
from einops import rearrange
2931
from flax.linen import partitioning as nn_partitioning
@@ -39,6 +41,28 @@
3941
get_precision,
4042
setup_initial_state,
4143
)
44+
from maxdiffusion.loaders.flux_lora_pipeline import FluxLoraLoaderMixin
45+
46+
def maybe_load_flux_lora(config, lora_loader, params):
  """Optionally load LoRA weights and build flax method interceptors.

  Args:
    config: config object exposing ``lora_config``, a dict of parallel lists
      under "lora_model_name_or_path", "weight_name" and "adapter_name"
      (one entry per LoRA adapter to load).
    lora_loader: loader exposing ``load_lora_weights`` and
      ``make_lora_interceptor`` (e.g. ``FluxLoraLoaderMixin``).
    params: dict of model params; LoRA weights are merged into it.

  Returns:
    (params, interceptors): the (possibly LoRA-augmented) params, and a
    non-empty list of linen interceptors — one per adapter, or a single
    pass-through no-op when no LoRA model is configured, so callers can
    always install the interceptors unconditionally.
  """

  def _noop_interceptor(next_fn, args, kwargs, context):
    # Pass-through: keeps calling code uniform when no LoRA is loaded.
    return next_fn(*args, **kwargs)

  lora_config = config.lora_config
  model_paths = lora_config["lora_model_name_or_path"]
  if not model_paths:
    return params, [_noop_interceptor]

  interceptors = []
  # The three config lists are parallel: (path, weight file, adapter name).
  for model_path, weight_name, adapter_name in zip(
      model_paths, lora_config["weight_name"], lora_config["adapter_name"]
  ):
    params, rank, network_alphas = lora_loader.load_lora_weights(
        config,
        model_path,
        weight_name=weight_name,
        params=params,
        adapter_name=adapter_name,
    )
    interceptors.append(lora_loader.make_lora_interceptor(params, rank, network_alphas, adapter_name))

  return params, interceptors
4266

4367

4468
def unpack(x: Array, height: int, width: int) -> Array:
@@ -400,21 +424,29 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
400424

401425
# loads pretrained weights
402426
transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu")
427+
params = {}
428+
params["transformer"] = transformer_params
429+
# maybe load lora and create interceptor
430+
lora_loader = FluxLoraLoaderMixin()
431+
params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params)
432+
transformer_params = params["transformer"]
403433
# create transformer state
404434
weights_init_fn = functools.partial(
405435
transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False
406436
)
407-
transformer_state, transformer_state_shardings = setup_initial_state(
408-
model=transformer,
409-
tx=None,
410-
config=config,
411-
mesh=mesh,
412-
weights_init_fn=weights_init_fn,
413-
model_params=None,
414-
training=False,
415-
)
416-
transformer_state = transformer_state.replace(params=transformer_params)
417-
transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
437+
with ExitStack() as stack:
438+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
439+
transformer_state, transformer_state_shardings = setup_initial_state(
440+
model=transformer,
441+
tx=None,
442+
config=config,
443+
mesh=mesh,
444+
weights_init_fn=weights_init_fn,
445+
model_params=None,
446+
training=False,
447+
)
448+
transformer_state = transformer_state.replace(params=transformer_params)
449+
transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
418450
get_memory_allocations()
419451

420452
states = {}
@@ -444,17 +476,23 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
444476
out_shardings=None,
445477
)
446478
t0 = time.perf_counter()
447-
p_run_inference(states).block_until_ready()
479+
with ExitStack() as stack:
480+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
481+
p_run_inference(states).block_until_ready()
448482
t1 = time.perf_counter()
449483
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")
450484

451485
t0 = time.perf_counter()
452-
imgs = p_run_inference(states).block_until_ready()
486+
with ExitStack() as stack, jax.profiler.trace("/home/jfacevedo/trace/"):
487+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
488+
imgs = p_run_inference(states).block_until_ready()
453489
t1 = time.perf_counter()
454490
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
455491

456492
t0 = time.perf_counter()
457-
imgs = p_run_inference(states).block_until_ready()
493+
with ExitStack() as stack:
494+
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
495+
imgs = p_run_inference(states).block_until_ready()
458496
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
459497
t1 = time.perf_counter()
460498
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")

src/maxdiffusion/loaders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
# limitations under the License.
1414

1515
from .lora_pipeline import StableDiffusionLoraLoaderMixin
16-
from .flux_lora_pipeline import FluxLoraLoaderMixin
16+
from .flux_lora_pipeline import FluxLoraLoaderMixin

src/maxdiffusion/loaders/flux_lora_pipeline.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,30 @@
1616
from .lora_base import LoRABaseMixin
1717
from ..models.lora import LoRALinearLayer, BaseLoRALayer
1818
import jax.numpy as jnp
19-
from flax.traverse_util import flatten_dict
19+
from flax.traverse_util import flatten_dict, unflatten_dict
20+
from flax.core.frozen_dict import unfreeze
2021
from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax
2122
from huggingface_hub.utils import validate_hf_hub_args
22-
23-
23+
from maxdiffusion.models.modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor)
2424
class FluxLoraLoaderMixin(LoRABaseMixin):
2525

2626
_lora_lodable_modules = ["transformer", "text_encoder"]
27-
27+
2828
def load_lora_weights(
2929
self,
3030
config,
3131
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]],
3232
params,
3333
adapter_name=None,
34-
**kwargs,
34+
**kwargs
3535
):
3636
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3737

3838
params, rank, network_alphas = self.load_lora(
39-
config,
40-
state_dict,
41-
params=params,
42-
adapter_name=adapter_name,
39+
config,
40+
state_dict,
41+
params=params,
42+
adapter_name=adapter_name,
4343
)
4444

4545
return params, rank, network_alphas
@@ -53,7 +53,7 @@ def rename_for_interceptor(params_keys, network_alphas, adapter_name):
5353
new_layer_lora = layer_lora[: layer_lora.index(lora_name)]
5454
if new_layer_lora not in new_params_keys:
5555
new_params_keys.append(new_layer_lora)
56-
network_alpha = network_alphas.get(layer_lora, None)
56+
network_alpha = network_alphas[layer_lora]
5757
new_network_alphas[new_layer_lora] = network_alpha
5858
return new_params_keys, new_network_alphas
5959

@@ -64,7 +64,7 @@ def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
6464
transformer_keys = flatten_dict(params["transformer"]).keys()
6565
lora_keys, transformer_alphas = cls.rename_for_interceptor(transformer_keys, network_alphas, adapter_name)
6666
network_alphas_for_interceptor.update(transformer_alphas)
67-
67+
6868
def _intercept(next_fn, args, kwargs, context):
6969
mod = context.module
7070
while mod is not None:
@@ -107,6 +107,7 @@ def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs):
107107
revision = kwargs.pop("revision", None)
108108
subfolder = kwargs.pop("subfolder", None)
109109
weight_name = kwargs.pop("weight_name", None)
110+
unet_config = kwargs.pop("unet_config", None)
110111
use_safetensors = kwargs.pop("use_safetensors", None)
111112
resume_download = kwargs.pop("resume_download", False)
112113

@@ -137,8 +138,8 @@ def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs):
137138
)
138139

139140
return state_dict
140-
141+
141142
@classmethod
142143
def load_lora(cls, config, state_dict, params, adapter_name=None):
143144
params, rank, network_alphas = convert_flux_lora_pytorch_state_dict_to_flax(config, state_dict, params, adapter_name)
144-
return params, rank, network_alphas
145+
return params, rank, network_alphas

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,51 @@ def create_flax_params_from_pytorch_state(
222222
renamed_network_alphas[tuple(flax_key_list)] = network_alpha_value
223223
return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, renamed_network_alphas
224224

225+
def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, adapter_name):
  """Convert a PyTorch Flux LoRA checkpoint into the Flax transformer params.

  Renames PyTorch LoRA keys to the MaxDiffusion Flux layer naming, transposes
  the kernels, and merges them into ``params["transformer"]``. Alpha scalars
  are not stored as weights; they are recorded in ``network_alphas`` under
  both the corresponding "down" and "up" kernel keys.

  Args:
    config: config providing ``weights_dtype`` for the converted tensors.
    pt_state_dict: PyTorch LoRA state dict (torch.Tensor values).
    params: dict holding the flax "transformer" params to merge into.
    adapter_name: adapter id; embedded in the new param keys as
      ``lora-{adapter_name}``.

  Returns:
    (params, rank, network_alphas): updated params, the LoRA rank inferred
    from the last "up" kernel seen (None if no "up" kernel was present), and
    the alpha lookup map.
  """
  pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()}
  transformer_params = flatten_dict(unfreeze(params["transformer"]))
  network_alphas = {}
  rank = None
  for pt_key, tensor in pt_state_dict.items():
    renamed_pt_key = rename_key(pt_key)
    renamed_pt_key = renamed_pt_key.replace("lora_unet_", "")
    renamed_pt_key = renamed_pt_key.replace("lora_down", f"lora-{adapter_name}.down")
    renamed_pt_key = renamed_pt_key.replace("lora_up", f"lora-{adapter_name}.up")

    # Map the flat PyTorch layer names onto the nested Flax module paths.
    if "double_blocks" in renamed_pt_key:
      renamed_pt_key = renamed_pt_key.replace("_img_attn_proj", ".attn.i_proj")
      renamed_pt_key = renamed_pt_key.replace("_img_attn_qkv", ".attn.i_qkv")
      renamed_pt_key = renamed_pt_key.replace("_img_mlp_0", ".img_mlp.layers_0")
      renamed_pt_key = renamed_pt_key.replace("_img_mlp_2", ".img_mlp.layers_2")
      renamed_pt_key = renamed_pt_key.replace("_img_mod_lin", ".img_norm1.lin")
      renamed_pt_key = renamed_pt_key.replace("_txt_attn_proj", ".attn.e_proj")
      renamed_pt_key = renamed_pt_key.replace("_txt_attn_qkv", ".attn.e_qkv")
      renamed_pt_key = renamed_pt_key.replace("_txt_mlp_0", ".txt_mlp.layers_0")
      renamed_pt_key = renamed_pt_key.replace("_txt_mlp_2", ".txt_mlp.layers_2")
      renamed_pt_key = renamed_pt_key.replace("_txt_mod_lin", ".txt_norm1.lin")
    elif "single_blocks" in renamed_pt_key:
      renamed_pt_key = renamed_pt_key.replace("_linear1", ".linear1")
      renamed_pt_key = renamed_pt_key.replace("_linear2", ".linear2")
      renamed_pt_key = renamed_pt_key.replace("_modulation_lin", ".norm.lin")

    renamed_pt_key = renamed_pt_key.replace("weight", "kernel")

    pt_tuple_key = tuple(renamed_pt_key.split("."))
    if "alpha" in pt_tuple_key:
      # Record the alpha scalar under both the down and up kernel keys so the
      # scaling can be looked up from either side of the LoRA pair.
      down_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "down", "kernel")
      network_alphas[down_key] = tensor.item()
      up_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "up", "kernel")
      network_alphas[up_key] = tensor.item()
    else:
      if pt_tuple_key[-2] == "up":
        # PT "up" weights are presumably (out_features, rank), so the second
        # dim is the LoRA rank — TODO confirm against checkpoint layout.
        rank = tensor.shape[1]
      # PyTorch linear weights are (out, in); Flax kernels are (in, out).
      transformer_params[pt_tuple_key] = jnp.asarray(tensor.T, dtype=config.weights_dtype)

  params["transformer"] = unflatten_dict(transformer_params)

  # NOTE(review): the diff shows a second, pre-existing definition of this
  # same function immediately after this one in the target file — the later
  # definition wins at import time; the duplicate should be removed.
  return params, rank, network_alphas
269+
225270

226271
def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, adapter_name):
227272
pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()}

0 commit comments

Comments
 (0)