
Commit 5e1d2b3

Formatting through pyink
1 parent e1b7221 commit 5e1d2b3

3 files changed, 591 additions & 578 deletions


src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 74 additions & 84 deletions
@@ -611,89 +611,79 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
 
 
 def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
-  """
-  Translates WAN NNX path to Diffusers/LoRA keys.
-  Verified against wan_utils.py mappings.
-  """
-
-  # --- 1. Embeddings (Exact Matches) ---
-  if nnx_path_str == 'condition_embedder.text_embedder.linear_1':
-    return 'diffusion_model.text_embedding.0'
-  if nnx_path_str == 'condition_embedder.text_embedder.linear_2':
-    return 'diffusion_model.text_embedding.2'
-  if nnx_path_str == 'condition_embedder.time_embedder.linear_1':
-    return 'diffusion_model.time_embedding.0'
-  if nnx_path_str == 'condition_embedder.time_embedder.linear_2':
-    return 'diffusion_model.time_embedding.2'
-  if nnx_path_str == 'condition_embedder.image_embedder.norm1.layer_norm':
-    return 'diffusion_model.img_emb.proj.0'
-  if nnx_path_str == 'condition_embedder.image_embedder.ff.net_0':
-    return 'diffusion_model.img_emb.proj.1'
-  if nnx_path_str == 'condition_embedder.image_embedder.ff.net_2':
-    return 'diffusion_model.img_emb.proj.3'
-  if nnx_path_str == 'condition_embedder.image_embedder.norm2.layer_norm':
-    return 'diffusion_model.img_emb.proj.4'
-  if nnx_path_str == 'patch_embedding':
-    return 'diffusion_model.patch_embedding'
-  if nnx_path_str == 'proj_out':
-    return 'diffusion_model.head.head'
-  if nnx_path_str == 'condition_embedder.time_proj':
-    return 'diffusion_model.time_projection.1'
-
-
-
-
-  # --- 2. Map NNX Suffixes to LoRA Suffixes ---
-  suffix_map = {
-      # Self Attention (attn1)
-      "attn1.query": "self_attn.q",
-      "attn1.key": "self_attn.k",
-      "attn1.value": "self_attn.v",
-      "attn1.proj_attn": "self_attn.o",
-
-      # Self Attention Norms (QK Norm)
-      "attn1.norm_q": "self_attn.norm_q",
-      "attn1.norm_k": "self_attn.norm_k",
-
-      # Cross Attention (attn2)
-      "attn2.query": "cross_attn.q",
-      "attn2.key": "cross_attn.k",
-      "attn2.value": "cross_attn.v",
-      "attn2.proj_attn": "cross_attn.o",
-
-      # Cross Attention Norms (QK Norm)
-      "attn2.norm_q": "cross_attn.norm_q",
-      "attn2.norm_k": "cross_attn.norm_k",
-
-      # Cross Attention img
-      "attn2.add_k_proj": "cross_attn.k_img",
-      "attn2.add_v_proj": "cross_attn.v_img",
-      "attn2.norm_added_k": "cross_attn.norm_k_img",
-
-      # Feed Forward (ffn)
-      "ffn.act_fn.proj": "ffn.0", # Up proj
-      "ffn.proj_out": "ffn.2", # Down proj
-
-      # Global Norms & Modulation
-      "norm2.layer_norm": "norm3",
-      "scale_shift_table": "modulation",
-      "proj_out": "head.head"
-  }
-
-  # --- 3. Translation Logic ---
-  if scan_layers:
-    # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
-    if nnx_path_str.startswith("blocks."):
-      inner_suffix = nnx_path_str[len("blocks."):]
-      if inner_suffix in suffix_map:
-        return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}"
-  else:
-    # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q"
-    m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
-    if m:
-      idx, inner_suffix = m.group(1), m.group(2)
-      if inner_suffix in suffix_map:
-        return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
+  """
+  Translates WAN NNX path to Diffusers/LoRA keys.
+  Verified against wan_utils.py mappings.
+  """
 
-  return None
+  # --- 1. Embeddings (Exact Matches) ---
+  if nnx_path_str == "condition_embedder.text_embedder.linear_1":
+    return "diffusion_model.text_embedding.0"
+  if nnx_path_str == "condition_embedder.text_embedder.linear_2":
+    return "diffusion_model.text_embedding.2"
+  if nnx_path_str == "condition_embedder.time_embedder.linear_1":
+    return "diffusion_model.time_embedding.0"
+  if nnx_path_str == "condition_embedder.time_embedder.linear_2":
+    return "diffusion_model.time_embedding.2"
+  if nnx_path_str == "condition_embedder.image_embedder.norm1.layer_norm":
+    return "diffusion_model.img_emb.proj.0"
+  if nnx_path_str == "condition_embedder.image_embedder.ff.net_0":
+    return "diffusion_model.img_emb.proj.1"
+  if nnx_path_str == "condition_embedder.image_embedder.ff.net_2":
+    return "diffusion_model.img_emb.proj.3"
+  if nnx_path_str == "condition_embedder.image_embedder.norm2.layer_norm":
+    return "diffusion_model.img_emb.proj.4"
+  if nnx_path_str == "patch_embedding":
+    return "diffusion_model.patch_embedding"
+  if nnx_path_str == "proj_out":
+    return "diffusion_model.head.head"
+  if nnx_path_str == "condition_embedder.time_proj":
+    return "diffusion_model.time_projection.1"
+
+  # --- 2. Map NNX Suffixes to LoRA Suffixes ---
+  suffix_map = {
+      # Self Attention (attn1)
+      "attn1.query": "self_attn.q",
+      "attn1.key": "self_attn.k",
+      "attn1.value": "self_attn.v",
+      "attn1.proj_attn": "self_attn.o",
+      # Self Attention Norms (QK Norm)
+      "attn1.norm_q": "self_attn.norm_q",
+      "attn1.norm_k": "self_attn.norm_k",
+      # Cross Attention (attn2)
+      "attn2.query": "cross_attn.q",
+      "attn2.key": "cross_attn.k",
+      "attn2.value": "cross_attn.v",
+      "attn2.proj_attn": "cross_attn.o",
+      # Cross Attention Norms (QK Norm)
+      "attn2.norm_q": "cross_attn.norm_q",
+      "attn2.norm_k": "cross_attn.norm_k",
+      # Cross Attention img
+      "attn2.add_k_proj": "cross_attn.k_img",
+      "attn2.add_v_proj": "cross_attn.v_img",
+      "attn2.norm_added_k": "cross_attn.norm_k_img",
+      # Feed Forward (ffn)
+      "ffn.act_fn.proj": "ffn.0",  # Up proj
+      "ffn.proj_out": "ffn.2",  # Down proj
+      # Global Norms & Modulation
+      "norm2.layer_norm": "norm3",
+      "scale_shift_table": "modulation",
+      "proj_out": "head.head",
+  }
 
+  # --- 3. Translation Logic ---
+  if scan_layers:
+    # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
+    if nnx_path_str.startswith("blocks."):
+      inner_suffix = nnx_path_str[len("blocks.") :]
+      if inner_suffix in suffix_map:
+        return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}"
+  else:
+    # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q"
+    m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
+    if m:
+      idx, inner_suffix = m.group(1), m.group(2)
+      if inner_suffix in suffix_map:
+        return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
+
+  return None
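
The reformatting above is behavior-preserving; the translator returns the same keys as before. As a quick illustration of the mapping, derived from the branches shown in the diff (the import path is assumed from the file's location and this sketch is not part of the commit):

    from maxdiffusion.loaders.lora_conversion_utils import translate_wan_nnx_path_to_diffusers_lora

    # Exact-match embedding path.
    translate_wan_nnx_path_to_diffusers_lora("condition_embedder.time_proj")
    # -> "diffusion_model.time_projection.1"

    # Unscanned block path: the explicit block index is preserved.
    translate_wan_nnx_path_to_diffusers_lora("blocks.0.attn1.query")
    # -> "diffusion_model.blocks.0.self_attn.q"

    # Scanned block path: a "{}" placeholder stands in for the block index.
    translate_wan_nnx_path_to_diffusers_lora("blocks.attn1.query", scan_layers=True)
    # -> "diffusion_model.blocks.{}.self_attn.q"

    # Anything unrecognized falls through to None.
    translate_wan_nnx_path_to_diffusers_lora("some.unknown.path")
    # -> None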

src/maxdiffusion/loaders/wan_lora_nnx_loader.py

Lines changed: 16 additions & 18 deletions
@@ -21,6 +21,7 @@
 from .. import max_logging
 from . import lora_conversion_utils
 
+
 class Wan2_1NnxLoraLoader(LoRABaseMixin):
   """
   Handles loading LoRA weights into NNX-based WAN 2.1 model.
@@ -44,21 +45,21 @@ def load_lora_weights(
     lora_loader = StableDiffusionLoraLoaderMixin()
 
     merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora
+
     def translate_fn(nnx_path_str):
       return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
 
     # Handle high noise model
     if hasattr(pipeline, "transformer") and transformer_weight_name:
-      max_logging.log(f"Merging LoRA into transformer with rank={rank}")
-      h_state_dict, _ = lora_loader.lora_state_dict(
-          lora_model_path, weight_name=transformer_weight_name, **kwargs
-      )
-      merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn)
+      max_logging.log(f"Merging LoRA into transformer with rank={rank}")
+      h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
+      merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn)
     else:
-      max_logging.log("transformer not found or no weight name provided for LoRA.")
+      max_logging.log("transformer not found or no weight name provided for LoRA.")
 
     return pipeline
 
+
 class Wan2_2NnxLoraLoader(LoRABaseMixin):
   """
   Handles loading LoRA weights into NNX-based WAN 2.2 model.
@@ -83,27 +84,24 @@ def load_lora_weights(
     lora_loader = StableDiffusionLoraLoaderMixin()
 
     merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora
+
     def translate_fn(nnx_path_str: str):
       return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
 
     # Handle high noise model
     if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name:
-      max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}")
-      h_state_dict, _ = lora_loader.lora_state_dict(
-          lora_model_path, weight_name=high_noise_weight_name, **kwargs
-      )
-      merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn)
+      max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}")
+      h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=high_noise_weight_name, **kwargs)
+      merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn)
     else:
-      max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.")
+      max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.")
 
     # Handle low noise model
     if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name:
-      max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}")
-      l_state_dict, _ = lora_loader.lora_state_dict(
-          lora_model_path, weight_name=low_noise_weight_name, **kwargs
-      )
-      merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn)
+      max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}")
+      l_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=low_noise_weight_name, **kwargs)
+      merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn)
     else:
-      max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.")
+      max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.")
 
     return pipeline
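
For context, a hypothetical invocation of the WAN 2.2 loader. The full signature of load_lora_weights is truncated in this diff, so the argument names below are inferred from the method body; the import path, model path, and weight file names are placeholders:

    from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_2NnxLoraLoader

    loader = Wan2_2NnxLoraLoader()
    pipeline = loader.load_lora_weights(
        pipeline,                                         # a WAN 2.2 pipeline exposing high/low noise transformers
        lora_model_path="path/to/lora",                   # placeholder
        high_noise_weight_name="high_noise.safetensors",  # placeholder
        low_noise_weight_name="low_noise.safetensors",    # placeholder
        rank=32,
        scale=1.0,
        scan_layers=True,  # selects lora_nnx.merge_lora_for_scanned over lora_nnx.merge_lora
    )

Each transformer that exists on the pipeline and has a weight name gets its LoRA state dict loaded and merged in place; otherwise the loader logs a message and skips it.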
