Skip to content

Commit cc72fb6

Browse files
committed
Fix
1 parent b9ed4d8 commit cc72fb6

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

src/maxdiffusion/models/lora_nnx.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,9 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
224224

225225
assigned_count = 0
226226
for path, module in nnx.iter_graph(model):
227-
if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): continue
227+
if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)):
228+
max_logging.log(f"Skipping non-supported module type: {module}")
229+
continue
228230

229231
nnx_path_str = ".".join(map(str, path))
230232
lora_key = translate_fn(nnx_path_str) if translate_fn else None
@@ -319,7 +321,9 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
319321

320322
assigned_count = 0
321323
for path, module in nnx.iter_graph(model):
322-
if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): continue
324+
if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)):
325+
max_logging.log(f"Skipping non-supported module type: {module}")
326+
continue
323327

324328
nnx_path_str = ".".join(map(str, path))
325329
lora_key_template = translate_fn(nnx_path_str) if translate_fn else None

0 commit comments

Comments
 (0)