@@ -341,14 +341,46 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
341341 if lora_found_in_module :
342342 module .kernel .value = kernel_value_updated
343343 assigned_count += 1
344- max_logging .log (f"Merged LoRA into scanned layer { nnx_path_str } " )
345344 else :
346345 max_logging .log (f"Scanned layer { nnx_path_str } matched template but no LoRA weights found for any block." )
347346 else :
348347 max_logging .log (f"Scanned NNX layer '{ nnx_path_str } ' could not be translated to a LoRA key template." )
349348
350349 # Handle scanned Conv layers (ndim=5)
351350 elif isinstance (module , nnx .Conv ) and module .kernel .ndim == 5 :
352- max_logging .log (f"Merging LoRA into scanned Conv layers not implemented: { nnx_path_str } " )
351+ if module .kernel_size != (1 , 1 ):
352+ max_logging .warn (f"Skipping merge for scanned Conv layer { nnx_path_str } with kernel size { module .kernel_size } , only 1x1 is supported for merging." )
353+ continue
354+
355+ lora_key_template = translate_fn (nnx_path_str ) if translate_fn else None
356+ if lora_key_template :
357+ num_layers , _ , _ , in_features , out_features = module .kernel .shape
358+ kernel_value_updated = module .kernel .value
359+ lora_found_in_module = False
360+ for i in range (num_layers ):
361+ lora_key = lora_key_template .format (i )
362+ if lora_key in lora_params and "down" in lora_params [lora_key ] and "up" in lora_params [lora_key ]:
363+ weights = lora_params [lora_key ]
364+ down_w , up_w = weights ["down" ], weights ["up" ]
365+
366+ if down_w .ndim == 4 :
367+ down_w = jnp .squeeze (down_w )
368+ if up_w .ndim == 4 :
369+ up_w = jnp .squeeze (up_w )
370+
371+ rank = down_w .shape [0 ]
372+ alpha = weights .get ("alpha" , rank )
373+ current_scale = scale * alpha / rank
374+ delta_i = (down_w .T @ up_w .T ).reshape (1 , 1 , in_features , out_features ) * current_scale
375+ kernel_value_updated = kernel_value_updated .at [i ].add (delta_i )
376+ lora_found_in_module = True
377+
378+ if lora_found_in_module :
379+ module .kernel .value = kernel_value_updated
380+ assigned_count += 1
381+ else :
382+ max_logging .log (f"Scanned 1x1 Conv layer { nnx_path_str } matched template but no LoRA weights found for any block." )
383+ else :
384+ max_logging .log (f"Scanned 1x1 Conv layer '{ nnx_path_str } ' could not be translated to a LoRA key template." )
353385
354386 max_logging .log (f"Merged weights into { assigned_count } scanned layers in { type (model ).__name__ } ." )
0 commit comments