@@ -169,6 +169,10 @@ def _merge_lora_layer(module, weights, scale):
     if bias_diff is not None and isinstance(module, nnx.LayerNorm) and hasattr(module, "bias") and module.bias is not None:
       module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype)
       updated = True
+  elif isinstance(module, nnx.Param):
+    if "diff" in weights:
+      module.value += np.array(weights["diff"]).reshape(module.shape).astype(module.dtype)
+      updated = True
   elif isinstance(module, (nnx.Linear, nnx.Conv)):
     # Prepare LoRA terms
     down_w, up_w, current_scale = None, None, None
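
A minimal sketch of what the new nnx.Param branch does, assuming the weights dict follows the same {"diff": ndarray} layout the other branches use; the table shape here is illustrative:

import numpy as np
import jax.numpy as jnp
from flax import nnx

# Hypothetical bare parameter, e.g. an adaLN scale_shift_table.
table = nnx.Param(jnp.zeros((6, 512), dtype=jnp.float32))
weights = {"diff": np.full((6, 512), 0.01, dtype=np.float32)}  # illustrative entry

# Same in-place update as the new elif: add the full-tensor diff.
if "diff" in weights:
  table.value += np.array(weights["diff"]).reshape(table.shape).astype(table.dtype)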
@@ -252,7 +256,7 @@ def merge_lora(model: nnx.Module, state_dict: dict, rank: int, scale: float, tra
 
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
-    if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)):
+    if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed, nnx.Param)):
       continue
 
     nnx_path_str = ".".join(map(str, path))
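
nnx.Param has to be listed in the filter explicitly because nnx.iter_graph walks Variable leaves as well as Module nodes; a bare table that is not wrapped in a Linear/Norm module would otherwise be skipped. A small sketch with made-up names (iter_graph will also yield the kernel/bias Params nested inside the Linear):

import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):  # hypothetical module with a bare table
  def __init__(self):
    self.scale_shift_table = nnx.Param(jnp.zeros((6, 512)))
    self.proj = nnx.Linear(512, 512, rngs=nnx.Rngs(0))

for path, node in nnx.iter_graph(Block()):
  if isinstance(node, (nnx.Linear, nnx.Param)):
    print(".".join(map(str, path)), type(node).__name__)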
@@ -285,23 +289,19 @@ def merge_lora_for_scanned(
 
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
+    if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed, nnx.Param)):
+      continue
 
     nnx_path_str = ".".join(map(str, path))
     lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
-    max_logging.log(f"nnx_path_str: {nnx_path_str}")
-    if 'scale_shift_table' in nnx_path_str:
+
+    if 'adaln_scale_shift_table' in nnx_path_str:
       max_logging.log(f"adaln_scale_shift_table: {nnx_path_str}")
       max_logging.log(f"lora_key_template: {lora_key_template}")
-      max_logging.log(f"Module: {module}")
+      max_logging.log(f"Module ndim: {module.ndim}")
       max_logging.log(f"Module Type: {type(module)}")
       continue
 
-    if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)):
-      continue
-
-    nnx_path_str = ".".join(map(str, path))
-    lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
-
     if not lora_key_template:
       continue
 
@@ -315,6 +315,8 @@ def merge_lora_for_scanned(
       is_scanned = module.kernel.ndim == 3
     elif isinstance(module, nnx.Conv):
       is_scanned = module.kernel.ndim == 5
+    elif isinstance(module, nnx.Param):
+      is_scanned = module.ndim > 1  # Heuristic for scanned Param
 
     if not is_scanned:
       lora_key = lora_key_template
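
For Linear and Conv the scan check compares the kernel against its canonical rank (2-D and 4-D respectively), but a bare Param has no canonical rank, so the diff falls back to an ndim check. A sketch of the shapes involved (sizes are made up); note that a 2-D unscanned table would also pass this check, which is why the code calls it a heuristic:

import jax.numpy as jnp
from flax import nnx

per_layer = nnx.Param(jnp.zeros((1536,)))     # unscanned 1-D table: ndim == 1
scanned = nnx.Param(jnp.zeros((48, 6, 256)))  # layer axis prepended by scan
for p in (per_layer, scanned):
  print(p.ndim, p.ndim > 1)  # -> "1 False", then "3 True"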
@@ -366,6 +368,20 @@ def merge_lora_for_scanned(
         module.bias.value += bias_diffs_to_add.astype(module.bias.dtype)
       if updated_scale or updated_bias:
         assigned_count += 1
+    elif isinstance(module, nnx.Param):
+      num_layers = module.shape[0]
+      param_diffs_to_add = np.zeros_like(module.value)
+      updated = False
+      for i in range(num_layers):
+        lora_key = lora_key_template.format(i)
+        if lora_key in lora_params:
+          matched_keys.add(lora_key)
+          if "diff" in lora_params[lora_key]:
+            param_diffs_to_add[i] = np.array(lora_params[lora_key]["diff"]).reshape(module.shape[1:])
+            updated = True
+      if updated:
+        module.value += param_diffs_to_add.astype(module.dtype)
+        assigned_count += 1
     elif isinstance(module, (nnx.Linear, nnx.Conv)):
       is_linear = isinstance(module, nnx.Linear)
       is_conv = isinstance(module, nnx.Conv)
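
An end-to-end sketch of the scanned-Param branch above, with a made-up key template and shapes; lora_params entries follow the same {"diff": ndarray} layout:

import numpy as np
import jax.numpy as jnp
from flax import nnx

num_layers = 2
stacked = nnx.Param(jnp.zeros((num_layers, 6, 8)))  # hypothetical scanned table
lora_key_template = "blocks.{}.scale_shift_table"   # made-up template
lora_params = {"blocks.1.scale_shift_table": {"diff": np.ones((6, 8), dtype=np.float32)}}
matched_keys = set()

# Accumulate per-layer diffs into one stacked array, then add once.
param_diffs_to_add = np.zeros_like(stacked.value)
updated = False
for i in range(num_layers):
  lora_key = lora_key_template.format(i)
  if lora_key in lora_params:
    matched_keys.add(lora_key)
    if "diff" in lora_params[lora_key]:
      param_diffs_to_add[i] = np.array(lora_params[lora_key]["diff"]).reshape(stacked.shape[1:])
      updated = True
if updated:
  stacked.value += param_diffs_to_add.astype(stacked.dtype)  # layer 0 keeps a zero diff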