Skip to content

Commit 0240929

Browse files
committed
LoRA support for Modulation layer in WAN2.2
1 parent aeca1ce commit 0240929

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/loaders/wan_lora_nnx_loader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .lora_pipeline import StableDiffusionLoraLoaderMixin
2020
from ..models import lora_nnx
2121
from .. import max_logging
22-
from . import lora_conversion_utils, preprocess_wan_lora_dict
22+
from . import lora_conversion_utils
2323

2424

2525
class Wan2_1NNXLoraLoader(LoRABaseMixin):
@@ -53,7 +53,7 @@ def translate_fn(nnx_path_str):
5353
if hasattr(pipeline, "transformer") and transformer_weight_name:
5454
max_logging.log(f"Merging LoRA into transformer with rank={rank}")
5555
h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
56-
h_state_dict = preprocess_wan_lora_dict(h_state_dict)
56+
h_state_dict = lora_conversion_utils.preprocess_wan_lora_dict(h_state_dict)
5757
merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
5858
else:
5959
max_logging.log("transformer not found or no weight name provided for LoRA.")
@@ -94,7 +94,7 @@ def translate_fn(nnx_path_str: str):
9494
if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name:
9595
max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}")
9696
h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=high_noise_weight_name, **kwargs)
97-
h_state_dict = preprocess_wan_lora_dict(h_state_dict)
97+
h_state_dict = lora_conversion_utils.preprocess_wan_lora_dict(h_state_dict)
9898
merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
9999
else:
100100
max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.")
@@ -103,7 +103,7 @@ def translate_fn(nnx_path_str: str):
103103
if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name:
104104
max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}")
105105
l_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=low_noise_weight_name, **kwargs)
106-
l_state_dict = preprocess_wan_lora_dict(l_state_dict)
106+
l_state_dict = lora_conversion_utils.preprocess_wan_lora_dict(l_state_dict)
107107
merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn, dtype=dtype)
108108
else:
109109
max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.")

0 commit comments

Comments (0)