Skip to content

Commit 1872a01

Browse files
committed
test
1 parent 5f07ad0 commit 1872a01

2 files changed

Lines changed: 3 additions & 8 deletions

File tree

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
651651
return "diffusion_model.patch_embedding"
652652
if nnx_path_str == "proj_out":
653653
return "diffusion_model.head.head"
654-
if nnx_path_str == "adaln_scale_shift_table":
654+
if nnx_path_str == "scale_shift_table":
655655
return "diffusion_model.head.modulation"
656656
if nnx_path_str == "condition_embedder.time_proj":
657657
return "diffusion_model.time_projection.1"

src/maxdiffusion/models/lora_nnx.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -295,12 +295,6 @@ def merge_lora_for_scanned(
295295
nnx_path_str = ".".join(map(str, path))
296296
lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
297297

298-
if 'adaln_scale_shift_table' in nnx_path_str:
299-
max_logging.log(f"adaln_scale_shift_table: {nnx_path_str}")
300-
max_logging.log(f"lora_key_template: {lora_key_template}")
301-
max_logging.log(f"Module ndim: {module.ndim}")
302-
max_logging.log(f"Module Type: {type(module)}")
303-
304298
if not lora_key_template:
305299
continue
306300

@@ -315,7 +309,8 @@ def merge_lora_for_scanned(
315309
elif isinstance(module, nnx.Conv):
316310
is_scanned = module.kernel.ndim == 5
317311
elif isinstance(module, nnx.Param):
318-
is_scanned = module.ndim > 1 # Heuristic for scanned Param
312+
# Use template format to disambiguate: if template has {}, then it is scanned.
313+
is_scanned = ("{}" in lora_key_template)
319314

320315
if not is_scanned:
321316
lora_key = lora_key_template

0 commit comments

Comments
 (0)