
Commit 8004e6f

Linting and styling changes

1 parent f6ae9ae

3 files changed: 95 additions & 73 deletions

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 11 additions & 11 deletions
@@ -391,7 +391,7 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
   ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
   if not is_sparse:
     # down_weight is copied to each split
-    ait_sd.update({k: down_weight for k in ait_down_keys})
+    ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))

     # up_weight is split to each split
     ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
@@ -534,7 +534,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
   ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]

   # down_weight is copied to each split
-  ait_sd.update({k: down_weight for k in ait_down_keys})
+  ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))

   # up_weight is split to each split
   ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
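Both dict.fromkeys hunks above are the same lint-driven substitution: when every key maps to one shared value, dict.fromkeys(keys, value) states that directly, and comprehension linters (the file already carries flake8-comprehensions "# noqa" markers) flag the {k: v for k in keys} spelling. A standalone sketch of the equivalence, with placeholder names rather than the real tensors:

keys = ["a.lora_A.weight", "b.lora_A.weight"]
down_weight = [1.0, 2.0]  # stands in for the shared down-projection tensor

# Both spellings build the same mapping, and both share the single value object.
assert {k: down_weight for k in keys} == dict.fromkeys(keys, down_weight)
assert all(v is down_weight for v in dict.fromkeys(keys, down_weight).values())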
@@ -615,7 +615,7 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
   Translates WAN NNX path to Diffusers/LoRA keys.
   Verified against wan_utils.py mappings.
   """
-
+
   # --- 1. Embeddings (Exact Matches) ---
   if nnx_path_str == 'condition_embedder.text_embedder.linear_1':
     return 'diffusion_model.text_embedding.0'
@@ -640,11 +640,11 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
     "attn1.key": "self_attn.k",
     "attn1.value": "self_attn.v",
     "attn1.proj_attn": "self_attn.o",
-
+
     # Self Attention Norms (QK Norm) - Added per your request
     "attn1.norm_q": "self_attn.norm_q",
     "attn1.norm_k": "self_attn.norm_k",
-
+
     # Cross Attention (attn2)
     "attn2.query": "cross_attn.q",
     "attn2.key": "cross_attn.k",
@@ -654,22 +654,22 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
     # Cross Attention Norms (QK Norm) - Added per your request
     "attn2.norm_q": "cross_attn.norm_q",
     "attn2.norm_k": "cross_attn.norm_k",
-
+
     # Feed Forward (ffn)
     "ffn.act_fn.proj": "ffn.0",  # Up proj
     "ffn.proj_out": "ffn.2",  # Down proj
-
+
     # Global Norms & Modulation
-    "norm2.layer_norm": "norm3",
+    "norm2.layer_norm": "norm3",
     "scale_shift_table": "modulation",
-    "proj_out": "head.head"
+    "proj_out": "head.head"
   }

   # --- 3. Translation Logic ---
   if scan_layers:
     # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
     if nnx_path_str.startswith("blocks."):
-      inner_suffix = nnx_path_str[len("blocks."):]
+      inner_suffix = nnx_path_str[len("blocks."):]
       if inner_suffix in suffix_map:
         return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}"
       else:
@@ -679,6 +679,6 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
       idx, inner_suffix = m.group(1), m.group(2)
       if inner_suffix in suffix_map:
         return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
-
+
   return None

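Taken together, the hunks in this file describe a pure lookup: exact-match embedding paths, a suffix_map for block-level modules, and a scanned/unscanned split for the block index. A standalone sketch of that flow (suffix_map abbreviated; the regex shape for the unscanned branch is an assumption, since the hunk that defines it is not shown here):

import re

suffix_map = {"attn1.query": "self_attn.q", "ffn.proj_out": "ffn.2"}

def translate(nnx_path_str, scan_layers=False):
  if scan_layers:
    # Scanned checkpoints drop the block index, so it is left as a "{}" slot.
    if nnx_path_str.startswith("blocks."):
      inner = nnx_path_str[len("blocks."):]
      if inner in suffix_map:
        return f"diffusion_model.blocks.{{}}.{suffix_map[inner]}"
  else:
    m = re.match(r"blocks\.(\d+)\.(.+)", nnx_path_str)  # assumed regex shape
    if m and m.group(2) in suffix_map:
      return f"diffusion_model.blocks.{m.group(1)}.{suffix_map[m.group(2)]}"
  return None

print(translate("blocks.attn1.query", scan_layers=True))  # diffusion_model.blocks.{}.self_attn.q
print(translate("blocks.3.ffn.proj_out"))                 # diffusion_model.blocks.3.ffn.2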
src/maxdiffusion/loaders/wan_lora_nnx_loader.py

Lines changed: 4 additions & 4 deletions
@@ -15,8 +15,6 @@
 """NNX-based LoRA loader for WAN models."""

 from flax import nnx
-import jax
-import re
 from .lora_base import LoRABaseMixin
 from .lora_pipeline import StableDiffusionLoraLoaderMixin
 from ..models import lora_nnx
@@ -46,7 +44,8 @@ def load_lora_weights(
   lora_loader = StableDiffusionLoraLoaderMixin()

   merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora
-  translate_fn = lambda nnx_path_str: lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
+  def translate_fn(nnx_path_str):
+    return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)

   # Handle high noise model
   if hasattr(pipeline, "transformer") and transformer_weight_name:
@@ -84,7 +83,8 @@ def load_lora_weights(
   lora_loader = StableDiffusionLoraLoaderMixin()

   merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora
-  translate_fn = lambda nnx_path_str: lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
+  def translate_fn(nnx_path_str: str):
+    return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)

   # Handle high noise model
   if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name:
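The lambda-to-def rewrites in both hunks are the standard fix for pycodestyle's E731 ("do not assign a lambda expression, use a def"); a def also gets a real __name__ for debugging and, as the second hunk shows, can carry annotations. A standalone sketch of the pattern, with placeholder names:

def make_translate_fn(scan_layers):
  # Named inner function instead of `translate_fn = lambda ...`.
  def translate_fn(nnx_path_str: str):
    return (nnx_path_str, scan_layers)  # stands in for the real conversion call
  return translate_fn

translate_fn = make_translate_fn(scan_layers=True)
print(translate_fn.__name__)  # "translate_fn"; a lambda would report "<lambda>"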
