
Commit 9abd7ce

Fix WAN2.1 lora
1 parent bb277ad

2 files changed (236 additions, 179 deletions)

File tree

src/maxdiffusion/loaders/lora_conversion_utils.py (49 additions, 26 deletions)
```diff
@@ -618,35 +618,58 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
   template 'diffusion_model.blocks.{}.self_attn.k'.
   Returns None if no match.
   """
-  translation_map = {
-      "attn1": "self_attn",
-      "attn2": "cross_attn",
-      "query": "q",
-      "key": "k",
-      "value": "v",
-      "proj_attn": "o",
+
+  # Handle embeddings - exact paths
+  if nnx_path_str == "patch_embedding":
+    return "diffusion_model.patch_embedding"
+  if nnx_path_str == 'condition_embedder.text_embedder.linear_1':
+    return 'diffusion_model.text_embedding.0'
+  if nnx_path_str == 'condition_embedder.text_embedder.linear_2':
+    return 'diffusion_model.text_embedding.2'
+  if nnx_path_str == 'condition_embedder.time_embedder.linear_1':
+    return 'diffusion_model.time_embedding.0'
+  if nnx_path_str == 'condition_embedder.time_embedder.linear_2':
+    return 'diffusion_model.time_embedding.2'
+
+  # Translation for Attention and FFN layers
+  attn_ffn_map = {
+      "attn1.query": "self_attn.q",
+      "attn1.key": "self_attn.k",
+      "attn1.value": "self_attn.v",
+      "attn1.proj_attn": "self_attn.o",
+      "attn2.query": "cross_attn.q",
+      "attn2.key": "cross_attn.k",
+      "attn2.value": "cross_attn.v",
+      "attn2.proj_attn": "cross_attn.o",
       "ffn.act_fn.proj": "ffn.0",
       "ffn.proj_out": "ffn.2",
   }
-  suffix_pattern = r"(attn[12]\.(?:query|key|value|proj_attn)|ffn\.(?:act_fn\.proj|proj_out))"
+  # Translation for Norm layers
+  norm_map = {
+      "norm3": "norm3",
+      "attn1.norm_q": "self_attn.norm_q",
+      "attn1.norm_k": "self_attn.norm_k",
+      "attn2.norm_q": "cross_attn.norm_q",
+      "attn2.norm_k": "cross_attn.norm_k",
+  }
+
   if scan_layers:
-    m = re.match(r"^blocks\." + suffix_pattern + "$", nnx_path_str)
-    if not m:
-      return None
-    block_idx_str = "{}"
-    suffix = m.group(1)
+    # Handle scanned attn/ffn: blocks.attn1.query -> diffusion_model.blocks.{}.self_attn.q
+    for k, v in attn_ffn_map.items():
+      if nnx_path_str == f"blocks.{k}":
+        return f"diffusion_model.blocks.{{}}.{v}"
+    # Handle scanned norm: blocks.norm3 -> diffusion_model.blocks.{}.norm3
+    for k, v in norm_map.items():
+      if nnx_path_str == f"blocks.{k}":
+        return f"diffusion_model.blocks.{{}}.{v}"
   else:
-    m = re.match(r"^blocks\.(\d+)\." + suffix_pattern + "$", nnx_path_str)
-    if not m:
-      return None
-    block_idx_str = m.group(1)
-    suffix = m.group(2)
-
-  parts = suffix.split('.')
-  if parts[0] == 'attn1' or parts[0] == 'attn2':
-    lora_part1 = translation_map[parts[0]]
-    lora_part2 = translation_map[parts[1]]
-    return f"diffusion_model.blocks.{block_idx_str}.{lora_part1}.{lora_part2}"
-  elif suffix in translation_map:
-    return f"diffusion_model.blocks.{block_idx_str}.{translation_map[suffix]}"
+    # Handle non-scanned attn/ffn/norm: blocks.0.attn1.query -> diffusion_model.blocks.0.self_attn.q
+    m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
+    if m:
+      idx, suffix = m.group(1), m.group(2)
+      if suffix in attn_ffn_map:
+        return f"diffusion_model.blocks.{idx}.{attn_ffn_map[suffix]}"
+      if suffix in norm_map:
+        return f"diffusion_model.blocks.{idx}.{norm_map[suffix]}"
+
   return None
```
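
For reference, a minimal sketch of what the updated translator should now return, assuming `maxdiffusion` is installed so the function is importable from the module shown in the file tree; the sample NNX paths are illustrative and not taken from the commit's tests:

```python
# Illustrative check of the path translation after this commit. The import
# path is assumed from the file tree above; the inputs are hypothetical.
from maxdiffusion.loaders.lora_conversion_utils import translate_wan_nnx_path_to_diffusers_lora

# Embedding paths now translate via exact matches (new in this commit).
assert translate_wan_nnx_path_to_diffusers_lora("patch_embedding") == "diffusion_model.patch_embedding"
assert (translate_wan_nnx_path_to_diffusers_lora("condition_embedder.text_embedder.linear_1")
        == "diffusion_model.text_embedding.0")

# Non-scanned block paths keep their explicit index.
assert (translate_wan_nnx_path_to_diffusers_lora("blocks.0.attn1.query")
        == "diffusion_model.blocks.0.self_attn.q")

# Norm layers are now handled (the old suffix_pattern never matched them,
# so they previously fell through to None).
assert (translate_wan_nnx_path_to_diffusers_lora("blocks.3.attn2.norm_k")
        == "diffusion_model.blocks.3.cross_attn.norm_k")

# With scan_layers=True the block index becomes a "{}" placeholder.
assert (translate_wan_nnx_path_to_diffusers_lora("blocks.ffn.proj_out", scan_layers=True)
        == "diffusion_model.blocks.{}.ffn.2")
```

The design shift here is from one regex plus a per-token `translation_map` to exact-match dictionaries keyed on full suffixes, which is what lets the embedding and norm paths translate instead of returning `None`.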
