
Commit b1c63d7: Support other lora formats (#136)
1 parent 5058538
4 files changed: 57 additions & 21 deletions


src/maxdiffusion/generate_sdxl.py
Lines changed: 8 additions & 7 deletions

@@ -225,6 +225,7 @@ def run(config):
       pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False
   )
 
+  # load unet params from orbax checkpoint
   unet_params = load_params_from_path(
       config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "unet_state"
   )
@@ -253,14 +254,14 @@ def run(config):
   vae_state, vae_state_shardings = checkpoint_loader.create_vae_state(
       pipeline, params, checkpoint_item_name="vae_state", is_training=False
   )
-  text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
-      pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
-  )
-
-  text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state(
-      pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False
-  )
+  with nn.intercept_methods(lora_interceptor):
+    text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
+        pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
+    )
 
+    text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state(
+        pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False
+    )
   states = {}
   state_shardings = {}

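For context on the generate_sdxl.py change: flax.linen's nn.intercept_methods routes every module method call made inside the context manager through an interceptor function, which is how the LoRA interceptor wraps the text encoders without modifying their definitions. A minimal sketch of the mechanism; the scale_by_two interceptor is hypothetical and stands in for the real lora_interceptor:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical interceptor: run the wrapped method, then rescale its
# output. The real lora_interceptor uses this same hook point to add a
# low-rank update on top of each matched layer's output.
def scale_by_two(next_fn, args, kwargs, context):
    h = next_fn(*args, **kwargs)
    if context.method_name == "__call__":
        h = h * 2.0
    return h

model = nn.Dense(features=4)
x = jnp.ones((1, 8))
params = model.init(jax.random.PRNGKey(0), x)

with nn.intercept_methods(scale_by_two):
    y = model.apply(params, x)  # every __call__ is routed through scale_by_two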
src/maxdiffusion/loaders/lora_pipeline.py
Lines changed: 20 additions & 6 deletions

@@ -121,21 +121,35 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):
 
   def rename_for_interceptor(params_keys, network_alphas):
     new_params_keys = []
+    new_network_alphas = {}
     for layer_lora in params_keys:
       if "lora" in layer_lora:
         new_layer_lora = layer_lora[: layer_lora.index("lora")]
         if new_layer_lora not in new_params_keys:
           new_params_keys.append(new_layer_lora)
           network_alpha = network_alphas[layer_lora]
-          del network_alphas[layer_lora]
-          network_alphas[new_layer_lora] = network_alpha
-    return new_params_keys, network_alphas
+          new_network_alphas[new_layer_lora] = network_alpha
+    return new_params_keys, new_network_alphas
 
   @classmethod
   def make_lora_interceptor(cls, params, rank, network_alphas):
     # Only unet interceptor supported for now.
+    network_alphas_for_interceptor = {}
+
     unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()
-    unet_lora_keys, network_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
+    lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
+    network_alphas_for_interceptor.update(unet_alphas)
+
+    text_encoder_keys = flax.traverse_util.flatten_dict(params["text_encoder"]).keys()
+    text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas)
+    lora_keys.extend(text_encoder_keys)
+    network_alphas_for_interceptor.update(text_encoder_alphas)
+
+    if "text_encoder_2" in params.keys():
+      text_encoder_2_keys = flax.traverse_util.flatten_dict(params["text_encoder_2"]).keys()
+      text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas)
+      lora_keys.extend(text_encoder_2_keys)
+      network_alphas_for_interceptor.update(text_encoder_2_alphas)
 
     def _intercept(next_fn, args, kwargs, context):
       mod = context.module
@@ -146,8 +160,8 @@ def _intercept(next_fn, args, kwargs, context):
       h = next_fn(*args, **kwargs)
       if context.method_name == "__call__":
         module_path = context.module.path
-        if module_path in unet_lora_keys:
-          lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas)
+        if module_path in lora_keys:
+          lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor)
           return lora_layer(h, *args, **kwargs)
       return h

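The functional fix above: rename_for_interceptor previously deleted entries from the shared network_alphas dict while renaming, so passing the same dict a second time (now needed for the text encoders) would lose alphas; it now accumulates results into a fresh new_network_alphas. A sketch of what the helper computes, using hypothetical flattened keys and calling it standalone for illustration:

# Hypothetical keys, shaped like the tuples flax.traverse_util.flatten_dict
# yields for LoRA params; the "lora" segment marks where the module path ends.
params_keys = [
    ("text_encoder", "layers_0", "self_attn", "q_proj", "lora", "up", "kernel"),
    ("text_encoder", "layers_0", "self_attn", "q_proj", "lora", "down", "kernel"),
]
network_alphas = {key: 8.0 for key in params_keys}

new_keys, new_alphas = rename_for_interceptor(params_keys, network_alphas)
# new_keys   == [("text_encoder", "layers_0", "self_attn", "q_proj")]
# new_alphas == {("text_encoder", "layers_0", "self_attn", "q_proj"): 8.0}
# Both the "up" and "down" keys collapse to one module path, which is what
# _intercept later matches against context.module.path.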
src/maxdiffusion/maxdiffusion_utils.py
Lines changed: 2 additions & 0 deletions

@@ -49,11 +49,13 @@ def _noop_interceptor(next_fn, args, kwargs, context):
   if len(lora_config["lora_model_name_or_path"]) > 0:
     # For now only first lora supported. In the future, they will be merged
     # before being loaded.
+    # TODO - merge LoRAs here.
     params, rank, network_alphas = pipeline.load_lora_weights(
         lora_config["lora_model_name_or_path"][0],
         weight_name=lora_config["weight_name"][0],
         params=params,
         adapter_name=lora_config["adapter_name"][0],
+        unet_config=pipeline.unet.config,
     )
     interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas)

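For reference, a hypothetical lora_config matching the fields this block reads (the real schema comes from maxdiffusion's YAML config; every value below is a placeholder):

lora_config = {
    # hypothetical Hugging Face repo id or local path
    "lora_model_name_or_path": ["some-user/sdxl-lora-example"],
    "weight_name": ["pytorch_lora_weights.safetensors"],
    "adapter_name": ["example_adapter"],
}

The new unet_config=pipeline.unet.config argument presumably gives the loader enough information to tell unet keys apart from text-encoder keys when renaming checkpoint weights.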
src/maxdiffusion/models/modeling_flax_pytorch_utils.py
Lines changed: 27 additions & 8 deletions

@@ -137,7 +137,15 @@ def create_flax_params_from_pytorch_state(
   # Need to change some parameters name to match Flax names
   for pt_key, pt_tensor in pt_state_dict.items():
     network_alpha_value = get_network_alpha_value(pt_key, network_alphas)
-    renamed_pt_key = rename_key(pt_key)
+
+    # rename text encoders fc1 lora layers.
+    pt_key = pt_key.replace("lora_linear_layer", "lora")
+
+    # only rename the unet keys, text encoders are already correct.
+    if "unet" in pt_key:
+      renamed_pt_key = rename_key(pt_key)
+    else:
+      renamed_pt_key = pt_key
     pt_tuple_key = tuple(renamed_pt_key.split("."))
     # conv
     if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
@@ -147,13 +155,24 @@
       flax_tensor = pt_tensor
     else:
       flax_key_list = [*pt_tuple_key]
-      for rename_from, rename_to in (
-          ("to_k_lora", ("to_k", "lora")),
-          ("to_q_lora", ("to_q", "lora")),
-          ("to_v_lora", ("to_v", "lora")),
-          ("to_out_lora", ("to_out_0", "lora")),
-          ("weight", "kernel"),
-      ):
+      if "text_encoder" in pt_tuple_key or "text_encoder_2" in pt_tuple_key:
+        rename_from_to = (
+            ("to_k_lora", ("k_proj", "lora")),
+            ("to_q_lora", ("q_proj", "lora")),
+            ("to_v_lora", ("v_proj", "lora")),
+            ("to_out_lora", ("out_proj", "lora")),
+            ("weight", "kernel"),
+        )
+      # the unet
+      else:
+        rename_from_to = (
+            ("to_k_lora", ("to_k", "lora")),
+            ("to_q_lora", ("to_q", "lora")),
+            ("to_v_lora", ("to_v", "lora")),
+            ("to_out_lora", ("to_out_0", "lora")),
+            ("weight", "kernel"),
+        )
+      for rename_from, rename_to in rename_from_to:
         tmp = []
         for s in flax_key_list:
           if s == rename_from:

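A worked example of the new text-encoder path in create_flax_params_from_pytorch_state, using a hypothetical checkpoint key:

# Hypothetical PyTorch LoRA key for a text encoder.
pt_key = "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight"

# Step 1: normalize the LoRA layer name.
pt_key = pt_key.replace("lora_linear_layer", "lora")

# Step 2: text-encoder keys skip rename_key(); only keys containing
# "unet" go through the unet renaming.
pt_tuple_key = tuple(pt_key.split("."))

# Step 3: the text-encoder rename table maps "weight" -> "kernel" (and
# to_{q,k,v,out}_lora -> {q,k,v,out}_proj + "lora" for diffusers-style
# attention keys), giving:
# ("text_encoder", "text_model", "encoder", "layers", "0",
#  "self_attn", "q_proj", "lora", "up", "kernel")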