@@ -612,16 +612,11 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
 
 def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
     """
-    Translates WAN NNX path like 'blocks.10.attn1.key' (scan_layers=False) or
-    'blocks.attn1.key' (scan_layers=True) to
-    LoRA path like 'diffusion_model.blocks.10.self_attn.k' or
-    template 'diffusion_model.blocks.{}.self_attn.k'.
-    Returns None if no match.
+    Translates a WAN NNX path to the matching Diffusers LoRA key (verified
+    against the wan_utils.py mappings); returns None if no mapping exists.
     """
 
-    # Handle embeddings - exact paths
-    if nnx_path_str == "patch_embedding":
-        return "diffusion_model.patch_embedding"
+    # --- 1. Embeddings (Exact Matches) ---
     if nnx_path_str == 'condition_embedder.text_embedder.linear_1':
         return 'diffusion_model.text_embedding.0'
     if nnx_path_str == 'condition_embedder.text_embedder.linear_2':
@@ -630,46 +625,55 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
         return 'diffusion_model.time_embedding.0'
     if nnx_path_str == 'condition_embedder.time_embedder.linear_2':
         return 'diffusion_model.time_embedding.2'
-
-    # Translation for Attention and FFN layers
-    attn_ffn_map = {
-        "attn1.query": "self_attn.q",
-        "attn1.key": "self_attn.k",
-        "attn1.value": "self_attn.v",
+    if nnx_path_str == 'patch_embedding':
+        return 'diffusion_model.patch_embedding'
+
+    # --- 2. Map NNX Suffixes to LoRA Suffixes ---
+    suffix_map = {
+        # Self Attention (attn1)
+        "attn1.query": "self_attn.q",
+        "attn1.key": "self_attn.k",
+        "attn1.value": "self_attn.v",
         "attn1.proj_attn": "self_attn.o",
-        "attn2.query": "cross_attn.q",
-        "attn2.key": "cross_attn.k",
-        "attn2.value": "cross_attn.v",
+
+        # Self Attention Norms (QK Norm)
+        "attn1.norm_q": "self_attn.norm_q",
+        "attn1.norm_k": "self_attn.norm_k",
+
+        # Cross Attention (attn2)
+        "attn2.query": "cross_attn.q",
+        "attn2.key": "cross_attn.k",
+        "attn2.value": "cross_attn.v",
         "attn2.proj_attn": "cross_attn.o",
-        "ffn.act_fn.proj": "ffn.0",
-        "ffn.proj_out": "ffn.2",
-    }
-    # Translation for Norm layers
-    norm_map = {
-        "norm3": "norm3",
-        "attn1.norm_q": "self_attn.norm_q",
-        "attn1.norm_k": "self_attn.norm_k",
-        "attn2.norm_q": "cross_attn.norm_q",
-        "attn2.norm_k": "cross_attn.norm_k",
+
+        # Cross Attention Norms (QK Norm)
+        "attn2.norm_q": "cross_attn.norm_q",
+        "attn2.norm_k": "cross_attn.norm_k",
+
+        # Feed Forward (ffn)
+        "ffn.act_fn.proj": "ffn.0",  # Up proj
+        "ffn.proj_out": "ffn.2",     # Down proj
+
+        # Global Norms & Modulation
+        "norm2.layer_norm": "norm3",
+        "scale_shift_table": "modulation",
+        "proj_out": "head.head"
     }
 
+    # --- 3. Translation Logic ---
     if scan_layers:
-        # Handle scanned attn/ffn: blocks.attn1.query -> diffusion_model.blocks.{}.self_attn.q
-        for k, v in attn_ffn_map.items():
-            if nnx_path_str == f"blocks.{k}":
-                return f"diffusion_model.blocks.{{}}.{v}"
-        # Handle scanned norm: blocks.norm3 -> diffusion_model.blocks.{}.norm3
-        for k, v in norm_map.items():
-            if nnx_path_str == f"blocks.{k}":
-                return f"diffusion_model.blocks.{{}}.{v}"
+        # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
+        if nnx_path_str.startswith("blocks."):
+            inner_suffix = nnx_path_str[len("blocks."):]
+            if inner_suffix in suffix_map:
+                return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}"
     else:
-        # Handle non-scanned attn/ffn/norm: blocks.0.attn1.query -> diffusion_model.blocks.0.self_attn.q
+        # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q"
         m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
         if m:
-            idx, suffix = m.group(1), m.group(2)
-            if suffix in attn_ffn_map:
-                return f"diffusion_model.blocks.{idx}.{attn_ffn_map[suffix]}"
-            if suffix in norm_map:
-                return f"diffusion_model.blocks.{idx}.{norm_map[suffix]}"
+            idx, inner_suffix = m.group(1), m.group(2)
+            if inner_suffix in suffix_map:
+                return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
 
     return None
+
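For reference, a minimal usage sketch of the unified mapping above. It is not part of this commit: the call sites and the .format() expansion of the scanned "{}" template are illustrative assumptions about how the returned keys are consumed.

# Illustrative only; assumes translate_wan_nnx_path_to_diffusers_lora is in scope.

# Unscanned checkpoints keep an explicit layer index in the NNX path.
assert translate_wan_nnx_path_to_diffusers_lora("blocks.10.attn1.query") == "diffusion_model.blocks.10.self_attn.q"

# Scanned checkpoints drop the index, so a "{}" template comes back instead;
# expanding it per layer with .format() is an assumption about downstream use.
template = translate_wan_nnx_path_to_diffusers_lora("blocks.ffn.act_fn.proj", scan_layers=True)
assert template == "diffusion_model.blocks.{}.ffn.0"
assert template.format(3) == "diffusion_model.blocks.3.ffn.0"

# Embedding paths map to fixed keys; anything unrecognised returns None.
assert translate_wan_nnx_path_to_diffusers_lora("condition_embedder.text_embedder.linear_1") == "diffusion_model.text_embedding.0"
assert translate_wan_nnx_path_to_diffusers_lora("unrelated.param") is None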