@@ -391,7 +391,7 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
     if not is_sparse:
         # down_weight is copied to each split
-        ait_sd.update({k: down_weight for k in ait_down_keys})
+        ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
 
         # up_weight is split to each split
         ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
@@ -534,7 +534,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
 
     # down_weight is copied to each split
-    ait_sd.update({k: down_weight for k in ait_down_keys})
+    ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
 
     # up_weight is split to each split
     ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
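Both hunks above (in _convert_to_ai_toolkit_cat and in handle_qkv) make the same refactor: when down_weight is fanned out to every split, a dict comprehension whose value is constant can be written as dict.fromkeys. A minimal sketch of the equivalence, using made-up key names and a placeholder tensor (not part of the diff):

# Sketch only: "ait_down_keys" here are illustrative names, and the tensor is a stand-in.
import torch

down_weight = torch.zeros(4, 8)  # placeholder LoRA "down" matrix
ait_down_keys = ["to_q.lora_A.weight", "to_k.lora_A.weight", "to_v.lora_A.weight"]

via_comprehension = {k: down_weight for k in ait_down_keys}
via_fromkeys = dict.fromkeys(ait_down_keys, down_weight)

assert via_comprehension.keys() == via_fromkeys.keys()
# Both forms bind every key to the same tensor object; no copies are made.
assert all(via_fromkeys[k] is down_weight for k in ait_down_keys)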
@@ -615,7 +615,7 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
     Translates WAN NNX path to Diffusers/LoRA keys.
     Verified against wan_utils.py mappings.
     """
-    
+
    # --- 1. Embeddings (Exact Matches) ---
    if nnx_path_str == 'condition_embedder.text_embedder.linear_1':
        return 'diffusion_model.text_embedding.0'
@@ -640,11 +640,11 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
        "attn1.key": "self_attn.k",
        "attn1.value": "self_attn.v",
        "attn1.proj_attn": "self_attn.o",
-        
+
        # Self Attention Norms (QK Norm) - Added per your request
        "attn1.norm_q": "self_attn.norm_q",
        "attn1.norm_k": "self_attn.norm_k",
-        
+
        # Cross Attention (attn2)
        "attn2.query": "cross_attn.q",
        "attn2.key": "cross_attn.k",
@@ -654,22 +654,22 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
        # Cross Attention Norms (QK Norm) - Added per your request
        "attn2.norm_q": "cross_attn.norm_q",
        "attn2.norm_k": "cross_attn.norm_k",
-        
+
        # Feed Forward (ffn)
        "ffn.act_fn.proj": "ffn.0",  # Up proj
        "ffn.proj_out": "ffn.2",  # Down proj
-        
+
        # Global Norms & Modulation
-        "norm2.layer_norm": "norm3", 
+        "norm2.layer_norm": "norm3",
        "scale_shift_table": "modulation",
-        "proj_out": "head.head" 
+        "proj_out": "head.head"
    }

    # --- 3. Translation Logic ---
    if scan_layers:
        # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
        if nnx_path_str.startswith("blocks."):
-            inner_suffix = nnx_path_str[len("blocks."):] 
+            inner_suffix = nnx_path_str[len("blocks."):]
            if inner_suffix in suffix_map:
                return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}"
    else:
@@ -679,6 +679,6 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
            idx, inner_suffix = m.group(1), m.group(2)
            if inner_suffix in suffix_map:
                return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
-    
+
    return None
 
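For reference, a self-contained sketch of the translation behavior the hunks above exercise: scanned checkpoints keep the layer index out of the key (so the result carries a "{}" placeholder), while per-layer paths embed the block index. The two-entry suffix map and the regex in the non-scanned branch are assumptions for illustration, not the full mapping from the file.

# Sketch under assumptions: simplified suffix_map subset; regex shape inferred from the hunks.
import re

suffix_map = {
    "attn1.query": "self_attn.q",
    "ffn.proj_out": "ffn.2",
}

def translate(nnx_path_str, scan_layers=False):
    if scan_layers:
        # Scanned pattern: "blocks.<suffix>" -> "diffusion_model.blocks.{}.<mapped suffix>"
        if nnx_path_str.startswith("blocks."):
            inner = nnx_path_str[len("blocks."):]
            if inner in suffix_map:
                return f"diffusion_model.blocks.{{}}.{suffix_map[inner]}"
    else:
        # Per-layer pattern: "blocks.<idx>.<suffix>" -> "diffusion_model.blocks.<idx>.<mapped suffix>"
        m = re.match(r"blocks\.(\d+)\.(.+)", nnx_path_str)
        if m and m.group(2) in suffix_map:
            return f"diffusion_model.blocks.{m.group(1)}.{suffix_map[m.group(2)]}"
    return None

assert translate("blocks.attn1.query", scan_layers=True) == "diffusion_model.blocks.{}.self_attn.q"
assert translate("blocks.7.ffn.proj_out") == "diffusion_model.blocks.7.ffn.2"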