@@ -224,7 +224,9 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
224 224
225 225    assigned_count = 0
226 226    for path, module in nnx.iter_graph(model):
227     -    if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): continue
    227 +    if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)):
    228 +      max_logging.log(f"Skipping non-supported module type: {module}")
    229 +      continue
228 230
229 231    nnx_path_str = ".".join(map(str, path))
230 232    lora_key = translate_fn(nnx_path_str) if translate_fn else None
@@ -319,7 +321,9 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
319 321
320 322    assigned_count = 0
321 323    for path, module in nnx.iter_graph(model):
322     -    if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): continue
    324 +    if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)):
    325 +      max_logging.log(f"Skipping non-supported module type: {module}")
    326 +      continue
323 327
324 328    nnx_path_str = ".".join(map(str, path))
325 329    lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
0 commit comments