@@ -324,8 +324,8 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
 
     if lora_key_template:
         num_layers, in_features, out_features = module.kernel.shape
-        deltas = []
-        has_lora = False
+        kernel_value_updated = module.kernel.value
+        lora_found_in_module = False
         for i in range(num_layers):
             lora_key = lora_key_template.format(i)
             if lora_key in lora_params and "down" in lora_params[lora_key] and "up" in lora_params[lora_key]:
@@ -335,15 +335,13 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
                 alpha = weights.get("alpha", rank)
                 current_scale = scale * alpha / rank
                 delta_i = (down_w.T @ up_w.T).reshape(in_features, out_features) * current_scale
-                deltas.append(delta_i)
-                has_lora = True
-            else:
-                deltas.append(jnp.zeros((in_features, out_features), dtype=module.kernel.dtype))
-
-        if has_lora:
-            stacked_delta = jnp.stack(deltas, axis=0)
-            module.kernel.value += stacked_delta
+                kernel_value_updated = kernel_value_updated.at[i].add(delta_i)
+                lora_found_in_module = True
+
+        if lora_found_in_module:
+            module.kernel.value = kernel_value_updated
             assigned_count += 1
+            max_logging.log(f"Merged LoRA into scanned layer {nnx_path_str}")
         else:
             max_logging.log(f"Scanned layer {nnx_path_str} matched template but no LoRA weights found for any block.")
     else:
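
Below is a minimal, self-contained sketch (not from the PR) of the jnp `.at[i].add(...)` functional-update pattern the new code relies on; the shapes, LoRA factor layout, and variable names are illustrative assumptions.

# Minimal illustration of the .at[i].add(...) functional update used above.
# Shapes and the LoRA factor layout (down: (rank, in_features), up: (out_features, rank))
# are assumptions for this sketch, not values taken from the PR.
import jax.numpy as jnp

num_layers, in_features, out_features = 4, 8, 16
rank, scale, alpha = 2, 1.0, 2.0

# Stacked kernel of a scanned layer: one (in_features, out_features) block per sub-layer.
kernel = jnp.zeros((num_layers, in_features, out_features))

# Hypothetical LoRA factors for block i.
i = 1
down_w = jnp.ones((rank, in_features))
up_w = jnp.ones((out_features, rank))

current_scale = scale * alpha / rank
delta_i = (down_w.T @ up_w.T).reshape(in_features, out_features) * current_scale

# JAX arrays are immutable, so .at[i].add(delta_i) returns a new array with only
# block i changed; this replaces the old build-a-delta-list-and-stack approach.
kernel = kernel.at[i].add(delta_i)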