Commit 266ae13

Fix: merge LoRA weights into scanned 1x1 Conv layers (previously logged as not implemented)
1 parent 18af7d4 commit 266ae13

1 file changed: src/maxdiffusion/models/lora_nnx.py (34 additions, 2 deletions)

@@ -341,14 +341,46 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
         if lora_found_in_module:
           module.kernel.value = kernel_value_updated
           assigned_count += 1
-          max_logging.log(f"Merged LoRA into scanned layer {nnx_path_str}")
         else:
           max_logging.log(f"Scanned layer {nnx_path_str} matched template but no LoRA weights found for any block.")
       else:
         max_logging.log(f"Scanned NNX layer '{nnx_path_str}' could not be translated to a LoRA key template.")
 
     # Handle scanned Conv layers (ndim=5)
     elif isinstance(module, nnx.Conv) and module.kernel.ndim == 5:
-      max_logging.log(f"Merging LoRA into scanned Conv layers not implemented: {nnx_path_str}")
+      if module.kernel_size != (1, 1):
+        max_logging.warn(f"Skipping merge for scanned Conv layer {nnx_path_str} with kernel size {module.kernel_size}, only 1x1 is supported for merging.")
+        continue
+
+      lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
+      if lora_key_template:
+        num_layers, _, _, in_features, out_features = module.kernel.shape
+        kernel_value_updated = module.kernel.value
+        lora_found_in_module = False
+        for i in range(num_layers):
+          lora_key = lora_key_template.format(i)
+          if lora_key in lora_params and "down" in lora_params[lora_key] and "up" in lora_params[lora_key]:
+            weights = lora_params[lora_key]
+            down_w, up_w = weights["down"], weights["up"]
+
+            if down_w.ndim == 4:
+              down_w = jnp.squeeze(down_w)
+            if up_w.ndim == 4:
+              up_w = jnp.squeeze(up_w)
+
+            rank = down_w.shape[0]
+            alpha = weights.get("alpha", rank)
+            current_scale = scale * alpha / rank
+            delta_i = (down_w.T @ up_w.T).reshape(1, 1, in_features, out_features) * current_scale
+            kernel_value_updated = kernel_value_updated.at[i].add(delta_i)
+            lora_found_in_module = True
+
+        if lora_found_in_module:
+          module.kernel.value = kernel_value_updated
+          assigned_count += 1
+        else:
+          max_logging.log(f"Scanned 1x1 Conv layer {nnx_path_str} matched template but no LoRA weights found for any block.")
+      else:
+        max_logging.log(f"Scanned 1x1 Conv layer '{nnx_path_str}' could not be translated to a LoRA key template.")
 
   max_logging.log(f"Merged weights into {assigned_count} scanned layers in {type(model).__name__}.")

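On the key-template contract assumed above: translate_fn must map a scanned NNX path to a LoRA key template containing a single {} placeholder, which the loop fills with the block index via lora_key_template.format(i); each resolved key then indexes a dict holding "down", "up", and an optional "alpha". A hypothetical sketch of that contract, with invented path and key names (the real mapping lives elsewhere in maxdiffusion):

def toy_translate_fn(nnx_path_str):
  # Invented mapping: scanned NNX path -> per-block LoRA key template,
  # where "{}" stands for the scanned block index.
  templates = {
      "single_blocks.proj_out": "transformer.single_blocks.{}.proj_out",
  }
  return templates.get(nnx_path_str)  # None hits the "could not be translated" log

# Each resolved key is then expected to look like:
# lora_params["transformer.single_blocks.0.proj_out"] = {
#     "down": ...,   # (rank, in_features) or (rank, in_features, 1, 1)
#     "up": ...,     # (out_features, rank) or (out_features, rank, 1, 1)
#     "alpha": ...,  # optional scalar; the merge defaults it to rank
# }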