Skip to content

Commit 85e3015

Browse files
committed
Fix
1 parent 43fccc1 commit 85e3015

1 file changed

Lines changed: 19 additions & 9 deletions

File tree

src/maxdiffusion/models/lora_nnx.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -207,17 +207,27 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
207207
"""
208208
lora_params = {}
209209
for k, v in state_dict.items():
210+
# Try matching diffusers rename format: "some.thing_lora.down.weight"
210211
m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
211-
if not m:
212-
m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
213-
214212
if m:
215-
module_path_str, weight_type = m.group(1), m.group(2)
216-
if module_path_str not in lora_params:
217-
lora_params[module_path_str] = {}
218-
lora_params[module_path_str][weight_type] = jnp.array(v)
213+
module_path_str, weight_type = m.group(1), m.group(2)
219214
else:
220-
max_logging.log(f"Could not parse LoRA key: {k}")
215+
# Try matching diffusers format: "some.thing.lora.down.weight"
216+
m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
217+
if m:
218+
module_path_str, weight_type = m.group(1), m.group(2)
219+
else:
220+
# Try matching kohya/lightning format: "some.thing.lora_down.weight"
221+
m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
222+
if m:
223+
module_path_str, weight_type = m.group(1), m.group(2).replace("lora_", "")
224+
else:
225+
max_logging.debug(f"Could not parse LoRA key: {k}")
226+
continue
227+
228+
if module_path_str not in lora_params:
229+
lora_params[module_path_str] = {}
230+
lora_params[module_path_str][weight_type] = jnp.array(v)
221231

222232
assigned_count = 0
223233
for path, module in nnx.iter_graph(model):
@@ -257,5 +267,5 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
257267
f"with kernel_size {module.kernel_size} > 1 is not supported."
258268
)
259269
else:
260-
max_logging.log(f"LoRA weights for {matched_key} incomplete.")
270+
max_logging.warning(f"LoRA weights for {matched_key} incomplete.")
261271
max_logging.log(f"Merged weights into {assigned_count} layers in {type(model).__name__}.")

0 commit comments

Comments
 (0)