@@ -224,13 +224,34 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
224224
225225 assigned_count = 0
226226 for path , module in nnx .iter_graph (model ):
227- if not isinstance (module , (nnx .Linear , nnx .Conv )): continue
227+ if not isinstance (module , (nnx .Linear , nnx .Conv , nnx . LayerNorm , nnx . RMSNorm , nnx . Embed )): continue
228228
229229 nnx_path_str = "." .join (map (str , path ))
230230 lora_key = translate_fn (nnx_path_str ) if translate_fn else None
231231
232232 if lora_key and lora_key in lora_params :
233233 weights = lora_params [lora_key ]
234+
235+ # Handle Embeddings
236+ if isinstance (module , nnx .Embed ):
237+ if "diff" in weights and hasattr (module , 'embedding' ):
238+ module .embedding .value += np .array (weights ["diff" ]).reshape (module .embedding .shape ).astype (module .embedding .dtype )
239+ assigned_count += 1
240+ continue
241+ # Handle Norms
242+ elif isinstance (module , (nnx .LayerNorm , nnx .RMSNorm )):
243+ scale_diff = weights .get ("diff" , None )
244+ bias_diff = weights .get ("diff_b" , None )
245+ updated = False
246+ if scale_diff is not None and hasattr (module , 'scale' ) and module .scale is not None :
247+ module .scale .value += np .array (scale_diff ).reshape (module .scale .shape ).astype (module .scale .dtype )
248+ updated = True
249+ if bias_diff is not None and isinstance (module , nnx .LayerNorm ) and hasattr (module , 'bias' ) and module .bias is not None :
250+ module .bias .value += np .array (bias_diff ).reshape (module .bias .shape ).astype (module .bias .dtype )
251+ updated = True
252+ if updated :
253+ assigned_count += 1
254+ continue
234255
235256 # Prepare LoRA terms
236257 down_w , up_w , current_scale = None , None , None
@@ -298,14 +319,53 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
298319
299320 assigned_count = 0
300321 for path , module in nnx .iter_graph (model ):
301- if not isinstance (module , (nnx .Linear , nnx .Conv )): continue
322+ if not isinstance (module , (nnx .Linear , nnx .Conv , nnx . LayerNorm , nnx . RMSNorm , nnx . Embed )): continue
302323
303324 nnx_path_str = "." .join (map (str , path ))
304325 lora_key_template = translate_fn (nnx_path_str ) if translate_fn else None
305326
306327 if not lora_key_template :
307328 continue
308329
330+ # Handle Scanned Embeddings
331+ if isinstance (module , nnx .Embed ) and hasattr (module , 'embedding' ) and module .embedding .ndim > 2 :
332+ num_layers = module .embedding .shape [0 ]
333+ embed_diffs_to_add = np .zeros_like (module .embedding .value )
334+ updated = False
335+ for i in range (num_layers ):
336+ lora_key = lora_key_template .format (i )
337+ if lora_key in lora_params and "diff" in lora_params [lora_key ]:
338+ embed_diffs_to_add [i ] = np .array (lora_params [lora_key ]["diff" ]).reshape (module .embedding .shape [1 :])
339+ updated = True
340+ if updated :
341+ module .embedding .value += embed_diffs_to_add .astype (module .embedding .dtype )
342+ assigned_count += 1
343+ continue
344+
345+ # Handle Scanned Norms
346+ if isinstance (module , (nnx .LayerNorm , nnx .RMSNorm )) and hasattr (module , 'scale' ) and module .scale is not None and module .scale .ndim > 1 :
347+ num_layers = module .scale .shape [0 ]
348+ scale_diffs_to_add = np .zeros_like (module .scale .value )
349+ bias_diffs_to_add = np .zeros_like (module .bias .value ) if isinstance (module , nnx .LayerNorm ) and hasattr (module , 'bias' ) and module .bias is not None else None
350+ updated_scale , updated_bias = False , False
351+ for i in range (num_layers ):
352+ lora_key = lora_key_template .format (i )
353+ if lora_key in lora_params :
354+ weights = lora_params [lora_key ]
355+ if "diff" in weights :
356+ scale_diffs_to_add [i ] = np .array (weights ["diff" ]).reshape (module .scale .shape [1 :])
357+ updated_scale = True
358+ if "diff_b" in weights and bias_diffs_to_add is not None :
359+ bias_diffs_to_add [i ] = np .array (weights ["diff_b" ]).reshape (module .bias .shape [1 :])
360+ updated_bias = True
361+ if updated_scale :
362+ module .scale .value += scale_diffs_to_add .astype (module .scale .dtype )
363+ if updated_bias and bias_diffs_to_add is not None :
364+ module .bias .value += bias_diffs_to_add .astype (module .bias .dtype )
365+ if updated_scale or updated_bias :
366+ assigned_count += 1
367+ continue
368+
309369 is_linear = isinstance (module , nnx .Linear ) and module .kernel .ndim == 3
310370 is_conv = isinstance (module , nnx .Conv ) and module .kernel .ndim == 5
311371
0 commit comments