@@ -608,3 +608,89 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
         raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")
 
     return new_state_dict
+
+
+def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
+    """
+    Translates a WAN NNX path into the corresponding Diffusers/LoRA key.
+    Verified against the wan_utils.py mappings.
+    """
+
+    # --- 1. Embeddings (Exact Matches) ---
+    if nnx_path_str == 'condition_embedder.text_embedder.linear_1':
+        return 'diffusion_model.text_embedding.0'
+    if nnx_path_str == 'condition_embedder.text_embedder.linear_2':
+        return 'diffusion_model.text_embedding.2'
+    if nnx_path_str == 'condition_embedder.time_embedder.linear_1':
+        return 'diffusion_model.time_embedding.0'
+    if nnx_path_str == 'condition_embedder.time_embedder.linear_2':
+        return 'diffusion_model.time_embedding.2'
+    if nnx_path_str == 'condition_embedder.image_embedder.norm1.layer_norm':
+        return 'diffusion_model.img_emb.proj.0'
+    if nnx_path_str == 'condition_embedder.image_embedder.ff.net_0':
+        return 'diffusion_model.img_emb.proj.1'
+    if nnx_path_str == 'condition_embedder.image_embedder.ff.net_2':
+        return 'diffusion_model.img_emb.proj.3'
+    if nnx_path_str == 'condition_embedder.image_embedder.norm2.layer_norm':
+        return 'diffusion_model.img_emb.proj.4'
+    if nnx_path_str == 'patch_embedding':
+        return 'diffusion_model.patch_embedding'
+    if nnx_path_str == 'proj_out':
+        return 'diffusion_model.head.head'
+    if nnx_path_str == 'condition_embedder.time_proj':
+        return 'diffusion_model.time_projection.1'
+
+    # --- 2. Map NNX Suffixes to LoRA Suffixes ---
+    suffix_map = {
+        # Self attention (attn1)
+        "attn1.query": "self_attn.q",
+        "attn1.key": "self_attn.k",
+        "attn1.value": "self_attn.v",
+        "attn1.proj_attn": "self_attn.o",
+
+        # Self attention norms (QK norm)
+        "attn1.norm_q": "self_attn.norm_q",
+        "attn1.norm_k": "self_attn.norm_k",
+
+        # Cross attention (attn2)
+        "attn2.query": "cross_attn.q",
+        "attn2.key": "cross_attn.k",
+        "attn2.value": "cross_attn.v",
+        "attn2.proj_attn": "cross_attn.o",
+
+        # Cross attention norms (QK norm)
+        "attn2.norm_q": "cross_attn.norm_q",
+        "attn2.norm_k": "cross_attn.norm_k",
+
+        # Cross attention image branch
+        "attn2.add_k_proj": "cross_attn.k_img",
+        "attn2.add_v_proj": "cross_attn.v_img",
+        "attn2.norm_added_k": "cross_attn.norm_k_img",
+
+        # Feed forward (ffn)
+        "ffn.act_fn.proj": "ffn.0",  # up proj
+        "ffn.proj_out": "ffn.2",  # down proj
+
+        # Global norms & modulation
+        "norm2.layer_norm": "norm3",
+        "scale_shift_table": "modulation",
+        "proj_out": "head.head",
+    }
+
+    # --- 3. Translation Logic ---
+    if scan_layers:
+        # Scanned pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
+        if nnx_path_str.startswith("blocks."):
+            inner_suffix = nnx_path_str[len("blocks."):]
+            if inner_suffix in suffix_map:
+                return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}"
+    else:
+        # Unscanned pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q"
+        m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
+        if m:
+            idx, inner_suffix = m.group(1), m.group(2)
+            if inner_suffix in suffix_map:
+                return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
+
+    return None
+
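For reference, here is a minimal usage sketch of the helper added above. It is not part of the commit: the block count is an assumption for illustration, and the function itself relies on a module-level `import re` that sits outside this hunk.

```python
# Minimal sketch exercising translate_wan_nnx_path_to_diffusers_lora as defined
# in the diff above (assumes it is in scope, e.g. imported from the edited module).

# Unscanned checkpoints keep an explicit block index in the NNX path.
key = translate_wan_nnx_path_to_diffusers_lora("blocks.0.attn1.query")
assert key == "diffusion_model.blocks.0.self_attn.q"

# Scanned checkpoints fold all blocks into one path, so the helper returns a
# template whose "{}" placeholder is filled in per block index afterwards.
template = translate_wan_nnx_path_to_diffusers_lora("blocks.attn1.query", scan_layers=True)
assert template == "diffusion_model.blocks.{}.self_attn.q"

num_blocks = 40  # assumed for illustration; the real count depends on the WAN variant
lora_keys = [template.format(i) for i in range(num_blocks)]

# Paths with no mapping (neither an exact match nor a known block suffix) yield None.
assert translate_wan_nnx_path_to_diffusers_lora("some.unknown.path") is None
```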