Skip to content

Commit 1b31b3c

Browse files
committed
Fix
1 parent e25067b commit 1b31b3c

1 file changed

Lines changed: 8 additions & 10 deletions

File tree

src/maxdiffusion/models/lora_nnx.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,8 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
324324

325325
if lora_key_template:
326326
num_layers, in_features, out_features = module.kernel.shape
327-
deltas = []
328-
has_lora = False
327+
kernel_value_updated = module.kernel.value
328+
lora_found_in_module = False
329329
for i in range(num_layers):
330330
lora_key = lora_key_template.format(i)
331331
if lora_key in lora_params and "down" in lora_params[lora_key] and "up" in lora_params[lora_key]:
@@ -335,15 +335,13 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
335335
alpha = weights.get("alpha", rank)
336336
current_scale = scale * alpha / rank
337337
delta_i = (down_w.T @ up_w.T).reshape(in_features, out_features) * current_scale
338-
deltas.append(delta_i)
339-
has_lora = True
340-
else:
341-
deltas.append(jnp.zeros((in_features, out_features), dtype=module.kernel.dtype))
342-
343-
if has_lora:
344-
stacked_delta = jnp.stack(deltas, axis=0)
345-
module.kernel.value += stacked_delta
338+
kernel_value_updated = kernel_value_updated.at[i].add(delta_i)
339+
lora_found_in_module = True
340+
341+
if lora_found_in_module:
342+
module.kernel.value = kernel_value_updated
346343
assigned_count += 1
344+
max_logging.log(f"Merged LoRA into scanned layer {nnx_path_str}")
347345
else:
348346
max_logging.log(f"Scanned layer {nnx_path_str} matched template but no LoRA weights found for any block.")
349347
else:

0 commit comments

Comments
 (0)