@@ -618,35 +618,58 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
618618 template 'diffusion_model.blocks.{}.self_attn.k'.
619619 Returns None if no match.
620620 """
621- translation_map = {
622- "attn1" : "self_attn" ,
623- "attn2" : "cross_attn" ,
624- "query" : "q" ,
625- "key" : "k" ,
626- "value" : "v" ,
627- "proj_attn" : "o" ,
621+
622+ # Handle embeddings - exact paths
623+ if nnx_path_str == "patch_embedding" :
624+ return "diffusion_model.patch_embedding"
625+ if nnx_path_str == 'condition_embedder.text_embedder.linear_1' :
626+ return 'diffusion_model.text_embedding.0'
627+ if nnx_path_str == 'condition_embedder.text_embedder.linear_2' :
628+ return 'diffusion_model.text_embedding.2'
629+ if nnx_path_str == 'condition_embedder.time_embedder.linear_1' :
630+ return 'diffusion_model.time_embedding.0'
631+ if nnx_path_str == 'condition_embedder.time_embedder.linear_2' :
632+ return 'diffusion_model.time_embedding.2'
633+
634+ # Translation for Attention and FFN layers
635+ attn_ffn_map = {
636+ "attn1.query" : "self_attn.q" ,
637+ "attn1.key" : "self_attn.k" ,
638+ "attn1.value" : "self_attn.v" ,
639+ "attn1.proj_attn" : "self_attn.o" ,
640+ "attn2.query" : "cross_attn.q" ,
641+ "attn2.key" : "cross_attn.k" ,
642+ "attn2.value" : "cross_attn.v" ,
643+ "attn2.proj_attn" : "cross_attn.o" ,
628644 "ffn.act_fn.proj" : "ffn.0" ,
629645 "ffn.proj_out" : "ffn.2" ,
630646 }
631- suffix_pattern = r"(attn[12]\.(?:query|key|value|proj_attn)|ffn\.(?:act_fn\.proj|proj_out))"
647+ # Translation for Norm layers
648+ norm_map = {
649+ "norm3" : "norm3" ,
650+ "attn1.norm_q" : "self_attn.norm_q" ,
651+ "attn1.norm_k" : "self_attn.norm_k" ,
652+ "attn2.norm_q" : "cross_attn.norm_q" ,
653+ "attn2.norm_k" : "cross_attn.norm_k" ,
654+ }
655+
632656 if scan_layers :
633- m = re .match (r"^blocks\." + suffix_pattern + "$" , nnx_path_str )
634- if not m :
635- return None
636- block_idx_str = "{}"
637- suffix = m .group (1 )
657+ # Handle scanned attn/ffn: blocks.attn1.query -> diffusion_model.blocks.{}.self_attn.q
658+ for k , v in attn_ffn_map .items ():
659+ if nnx_path_str == f"blocks.{ k } " :
660+ return f"diffusion_model.blocks.{{}}.{ v } "
661+ # Handle scanned norm: blocks.norm3 -> diffusion_model.blocks.{}.norm3
662+ for k , v in norm_map .items ():
663+ if nnx_path_str == f"blocks.{ k } " :
664+ return f"diffusion_model.blocks.{{}}.{ v } "
638665 else :
639- m = re .match (r"^blocks\.(\d+)\." + suffix_pattern + "$" , nnx_path_str )
640- if not m :
641- return None
642- block_idx_str = m .group (1 )
643- suffix = m .group (2 )
644-
645- parts = suffix .split ('.' )
646- if parts [0 ] == 'attn1' or parts [0 ] == 'attn2' :
647- lora_part1 = translation_map [parts [0 ]]
648- lora_part2 = translation_map [parts [1 ]]
649- return f"diffusion_model.blocks.{ block_idx_str } .{ lora_part1 } .{ lora_part2 } "
650- elif suffix in translation_map :
651- return f"diffusion_model.blocks.{ block_idx_str } .{ translation_map [suffix ]} "
666+ # Handle non-scanned attn/ffn/norm: blocks.0.attn1.query -> diffusion_model.blocks.0.self_attn.q
667+ m = re .match (r"^blocks\.(\d+)\.(.+)$" , nnx_path_str )
668+ if m :
669+ idx , suffix = m .group (1 ), m .group (2 )
670+ if suffix in attn_ffn_map :
671+ return f"diffusion_model.blocks.{ idx } .{ attn_ffn_map [suffix ]} "
672+ if suffix in norm_map :
673+ return f"diffusion_model.blocks.{ idx } .{ norm_map [suffix ]} "
674+
652675 return None
0 commit comments