@@ -608,3 +608,82 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
608608 raise ValueError (f"`old_state_dict` should be at this point but has: { list (old_state_dict .keys ())} ." )
609609
610610 return new_state_dict
611+
612+
613+ def translate_wan_nnx_path_to_diffusers_lora (nnx_path_str , scan_layers = False ):
614+ """
615+ Translates WAN NNX path to Diffusers/LoRA keys.
616+ Verified against wan_utils.py mappings.
617+ """
618+
619+ # --- 1. Embeddings (Exact Matches) ---
620+ if nnx_path_str == "condition_embedder.text_embedder.linear_1" :
621+ return "diffusion_model.text_embedding.0"
622+ if nnx_path_str == "condition_embedder.text_embedder.linear_2" :
623+ return "diffusion_model.text_embedding.2"
624+ if nnx_path_str == "condition_embedder.time_embedder.linear_1" :
625+ return "diffusion_model.time_embedding.0"
626+ if nnx_path_str == "condition_embedder.time_embedder.linear_2" :
627+ return "diffusion_model.time_embedding.2"
628+ if nnx_path_str == "condition_embedder.image_embedder.norm1.layer_norm" :
629+ return "diffusion_model.img_emb.proj.0"
630+ if nnx_path_str == "condition_embedder.image_embedder.ff.net_0" :
631+ return "diffusion_model.img_emb.proj.1"
632+ if nnx_path_str == "condition_embedder.image_embedder.ff.net_2" :
633+ return "diffusion_model.img_emb.proj.3"
634+ if nnx_path_str == "condition_embedder.image_embedder.norm2.layer_norm" :
635+ return "diffusion_model.img_emb.proj.4"
636+ if nnx_path_str == "patch_embedding" :
637+ return "diffusion_model.patch_embedding"
638+ if nnx_path_str == "proj_out" :
639+ return "diffusion_model.head.head"
640+ if nnx_path_str == "condition_embedder.time_proj" :
641+ return "diffusion_model.time_projection.1"
642+
643+ # --- 2. Map NNX Suffixes to LoRA Suffixes ---
644+ suffix_map = {
645+ # Self Attention (attn1)
646+ "attn1.query" : "self_attn.q" ,
647+ "attn1.key" : "self_attn.k" ,
648+ "attn1.value" : "self_attn.v" ,
649+ "attn1.proj_attn" : "self_attn.o" ,
650+ # Self Attention Norms (QK Norm)
651+ "attn1.norm_q" : "self_attn.norm_q" ,
652+ "attn1.norm_k" : "self_attn.norm_k" ,
653+ # Cross Attention (attn2)
654+ "attn2.query" : "cross_attn.q" ,
655+ "attn2.key" : "cross_attn.k" ,
656+ "attn2.value" : "cross_attn.v" ,
657+ "attn2.proj_attn" : "cross_attn.o" ,
658+ # Cross Attention Norms (QK Norm)
659+ "attn2.norm_q" : "cross_attn.norm_q" ,
660+ "attn2.norm_k" : "cross_attn.norm_k" ,
661+ # Cross Attention img
662+ "attn2.add_k_proj" : "cross_attn.k_img" ,
663+ "attn2.add_v_proj" : "cross_attn.v_img" ,
664+ "attn2.norm_added_k" : "cross_attn.norm_k_img" ,
665+ # Feed Forward (ffn)
666+ "ffn.act_fn.proj" : "ffn.0" , # Up proj
667+ "ffn.proj_out" : "ffn.2" , # Down proj
668+ # Global Norms & Modulation
669+ "norm2.layer_norm" : "norm3" ,
670+ "scale_shift_table" : "modulation" ,
671+ "proj_out" : "head.head" ,
672+ }
673+
674+ # --- 3. Translation Logic ---
675+ if scan_layers :
676+ # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
677+ if nnx_path_str .startswith ("blocks." ):
678+ inner_suffix = nnx_path_str [len ("blocks." ) :]
679+ if inner_suffix in suffix_map :
680+ return f"diffusion_model.blocks.{{}}.{ suffix_map [inner_suffix ]} "
681+ else :
682+ # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q"
683+ m = re .match (r"^blocks\.(\d+)\.(.+)$" , nnx_path_str )
684+ if m :
685+ idx , inner_suffix = m .group (1 ), m .group (2 )
686+ if inner_suffix in suffix_map :
687+ return f"diffusion_model.blocks.{ idx } .{ suffix_map [inner_suffix ]} "
688+
689+ return None
0 commit comments