@@ -611,89 +611,79 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
 
 
 def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
614- """
615- Translates WAN NNX path to Diffusers/LoRA keys.
616- Verified against wan_utils.py mappings.
617- """
618-
619- # --- 1. Embeddings (Exact Matches) ---
620- if nnx_path_str == 'condition_embedder.text_embedder.linear_1' :
621- return 'diffusion_model.text_embedding.0'
622- if nnx_path_str == 'condition_embedder.text_embedder.linear_2' :
623- return 'diffusion_model.text_embedding.2'
624- if nnx_path_str == 'condition_embedder.time_embedder.linear_1' :
625- return 'diffusion_model.time_embedding.0'
626- if nnx_path_str == 'condition_embedder.time_embedder.linear_2' :
627- return 'diffusion_model.time_embedding.2'
628- if nnx_path_str == 'condition_embedder.image_embedder.norm1.layer_norm' :
629- return 'diffusion_model.img_emb.proj.0'
630- if nnx_path_str == 'condition_embedder.image_embedder.ff.net_0' :
631- return 'diffusion_model.img_emb.proj.1'
632- if nnx_path_str == 'condition_embedder.image_embedder.ff.net_2' :
633- return 'diffusion_model.img_emb.proj.3'
634- if nnx_path_str == 'condition_embedder.image_embedder.norm2.layer_norm' :
635- return 'diffusion_model.img_emb.proj.4'
636- if nnx_path_str == 'patch_embedding' :
637- return 'diffusion_model.patch_embedding'
638- if nnx_path_str == 'proj_out' :
639- return 'diffusion_model.head.head'
640- if nnx_path_str == 'condition_embedder.time_proj' :
641- return 'diffusion_model.time_projection.1'
642-
643-
644-
645-
646- # --- 2. Map NNX Suffixes to LoRA Suffixes ---
647- suffix_map = {
648- # Self Attention (attn1)
649- "attn1.query" : "self_attn.q" ,
650- "attn1.key" : "self_attn.k" ,
651- "attn1.value" : "self_attn.v" ,
652- "attn1.proj_attn" : "self_attn.o" ,
653-
654- # Self Attention Norms (QK Norm)
655- "attn1.norm_q" : "self_attn.norm_q" ,
656- "attn1.norm_k" : "self_attn.norm_k" ,
657-
658- # Cross Attention (attn2)
659- "attn2.query" : "cross_attn.q" ,
660- "attn2.key" : "cross_attn.k" ,
661- "attn2.value" : "cross_attn.v" ,
662- "attn2.proj_attn" : "cross_attn.o" ,
663-
664- # Cross Attention Norms (QK Norm)
665- "attn2.norm_q" : "cross_attn.norm_q" ,
666- "attn2.norm_k" : "cross_attn.norm_k" ,
667-
668- # Cross Attention img
669- "attn2.add_k_proj" : "cross_attn.k_img" ,
670- "attn2.add_v_proj" : "cross_attn.v_img" ,
671- "attn2.norm_added_k" : "cross_attn.norm_k_img" ,
672-
673- # Feed Forward (ffn)
674- "ffn.act_fn.proj" : "ffn.0" , # Up proj
675- "ffn.proj_out" : "ffn.2" , # Down proj
676-
677- # Global Norms & Modulation
678- "norm2.layer_norm" : "norm3" ,
679- "scale_shift_table" : "modulation" ,
680- "proj_out" : "head.head"
681- }
682-
683- # --- 3. Translation Logic ---
684- if scan_layers :
685- # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
686- if nnx_path_str .startswith ("blocks." ):
687- inner_suffix = nnx_path_str [len ("blocks." ):]
688- if inner_suffix in suffix_map :
689- return f"diffusion_model.blocks.{{}}.{ suffix_map [inner_suffix ]} "
690- else :
691- # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q"
692- m = re .match (r"^blocks\.(\d+)\.(.+)$" , nnx_path_str )
693- if m :
694- idx , inner_suffix = m .group (1 ), m .group (2 )
695- if inner_suffix in suffix_map :
696- return f"diffusion_model.blocks.{ idx } .{ suffix_map [inner_suffix ]} "
614+ """
615+ Translates WAN NNX path to Diffusers/LoRA keys.
616+ Verified against wan_utils.py mappings.
617+ """
697618
698- return None
+    # --- 1. Embeddings (Exact Matches) ---
+    if nnx_path_str == "condition_embedder.text_embedder.linear_1":
+        return "diffusion_model.text_embedding.0"
+    if nnx_path_str == "condition_embedder.text_embedder.linear_2":
+        return "diffusion_model.text_embedding.2"
+    if nnx_path_str == "condition_embedder.time_embedder.linear_1":
+        return "diffusion_model.time_embedding.0"
+    if nnx_path_str == "condition_embedder.time_embedder.linear_2":
+        return "diffusion_model.time_embedding.2"
+    if nnx_path_str == "condition_embedder.image_embedder.norm1.layer_norm":
+        return "diffusion_model.img_emb.proj.0"
+    if nnx_path_str == "condition_embedder.image_embedder.ff.net_0":
+        return "diffusion_model.img_emb.proj.1"
+    if nnx_path_str == "condition_embedder.image_embedder.ff.net_2":
+        return "diffusion_model.img_emb.proj.3"
+    if nnx_path_str == "condition_embedder.image_embedder.norm2.layer_norm":
+        return "diffusion_model.img_emb.proj.4"
+    if nnx_path_str == "patch_embedding":
+        return "diffusion_model.patch_embedding"
+    if nnx_path_str == "proj_out":
+        return "diffusion_model.head.head"
+    if nnx_path_str == "condition_embedder.time_proj":
+        return "diffusion_model.time_projection.1"
+
+    # --- 2. Map NNX Suffixes to LoRA Suffixes ---
+    suffix_map = {
+        # Self Attention (attn1)
+        "attn1.query": "self_attn.q",
+        "attn1.key": "self_attn.k",
+        "attn1.value": "self_attn.v",
+        "attn1.proj_attn": "self_attn.o",
+        # Self Attention Norms (QK Norm)
+        "attn1.norm_q": "self_attn.norm_q",
+        "attn1.norm_k": "self_attn.norm_k",
+        # Cross Attention (attn2)
+        "attn2.query": "cross_attn.q",
+        "attn2.key": "cross_attn.k",
+        "attn2.value": "cross_attn.v",
+        "attn2.proj_attn": "cross_attn.o",
+        # Cross Attention Norms (QK Norm)
+        "attn2.norm_q": "cross_attn.norm_q",
+        "attn2.norm_k": "cross_attn.norm_k",
+        # Cross Attention img
+        "attn2.add_k_proj": "cross_attn.k_img",
+        "attn2.add_v_proj": "cross_attn.v_img",
+        "attn2.norm_added_k": "cross_attn.norm_k_img",
+        # Feed Forward (ffn)
+        "ffn.act_fn.proj": "ffn.0",  # Up proj
+        "ffn.proj_out": "ffn.2",  # Down proj
+        # Global Norms & Modulation
+        "norm2.layer_norm": "norm3",
+        "scale_shift_table": "modulation",
+        "proj_out": "head.head",
+    }
 
+    # --- 3. Translation Logic ---
+    if scan_layers:
+        # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
+        if nnx_path_str.startswith("blocks."):
+            inner_suffix = nnx_path_str[len("blocks.") :]
+            if inner_suffix in suffix_map:
+                return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}"
+    else:
+        # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q"
+        m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
+        if m:
+            idx, inner_suffix = m.group(1), m.group(2)
+            if inner_suffix in suffix_map:
+                return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
+
+    return None
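
For reference, a minimal usage sketch of the translator above. The paths are illustrative, and it assumes the function is in scope and that re is imported at module level, which the unscanned branch requires:

    # Unscanned checkpoint: the block index is explicit in the path.
    translate_wan_nnx_path_to_diffusers_lora("blocks.0.attn1.query")
    # -> "diffusion_model.blocks.0.self_attn.q"

    # Scanned checkpoint: the block axis is folded away, so the key keeps a "{}"
    # placeholder, presumably filled with a block index by the caller.
    translate_wan_nnx_path_to_diffusers_lora("blocks.ffn.act_fn.proj", scan_layers=True)
    # -> "diffusion_model.blocks.{}.ffn.0"

    # Paths with no known mapping fall through to None.
    translate_wan_nnx_path_to_diffusers_lora("vae.decoder.conv_in")
    # -> None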