Skip to content

Commit 95fe0ba

Browse files
committed
Add LoRA support for WAN models
1 parent ad56886 commit 95fe0ba

9 files changed

Lines changed: 792 additions & 16 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -317,12 +317,14 @@ lightning_repo: ""
317317
lightning_ckpt: ""
318318

319319
# LoRA parameters
320+
enable_lora: False
320321
# Values are lists to support multiple LoRA loading during inference in the future.
321322
lora_config: {
322-
lora_model_name_or_path: [],
323-
weight_name: [],
324-
adapter_name: [],
325-
scale: [],
323+
rank: [64],
324+
lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"],
325+
weight_name: ["wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors"],
326+
adapter_name: ["wan21-distill-lora"],
327+
scale: [1.0],
326328
from_pt: []
327329
}
328330
# Ex with values:

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,12 +316,15 @@ lightning_repo: ""
316316
lightning_ckpt: ""
317317

318318
# LoRA parameters
319+
enable_lora: False
319320
# Values are lists to support multiple LoRA loading during inference in the future.
320321
lora_config: {
321-
lora_model_name_or_path: [],
322-
weight_name: [],
323-
adapter_name: [],
324-
scale: [],
322+
rank: [64],
323+
lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"],
324+
high_noise_weight_name: ["wan2.2_t2v_A14b_high_noise_lora_rank64_lightx2v_4step_1217.safetensors"],
325+
low_noise_weight_name: ["wan2.2_t2v_A14b_low_noise_lora_rank64_lightx2v_4step_1217.safetensors"],
326+
adapter_name: ["wan22-distill-lora"],
327+
scale: [1.0],
325328
from_pt: []
326329
}
327330
# Ex with values:

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,12 +300,14 @@ lightning_repo: ""
300300
lightning_ckpt: ""
301301

302302
# LoRA parameters
303+
enable_lora: False
303304
# Values are lists to support multiple LoRA loading during inference in the future.
304305
lora_config: {
305-
lora_model_name_or_path: [],
306-
weight_name: [],
307-
adapter_name: [],
308-
scale: [],
306+
rank: [64],
307+
lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"],
308+
weight_name: ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors"],
309+
adapter_name: ["wan21-distill-lora-i2v"],
310+
scale: [1.0],
309311
from_pt: []
310312
}
311313
# Ex with values:

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,12 +312,15 @@ lightning_repo: ""
312312
lightning_ckpt: ""
313313

314314
# LoRA parameters
315+
enable_lora: False
315316
# Values are lists to support multiple LoRA loading during inference in the future.
316317
lora_config: {
317-
lora_model_name_or_path: [],
318-
weight_name: [],
319-
adapter_name: [],
320-
scale: [],
318+
rank: [64],
319+
lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"],
320+
high_noise_weight_name: ["wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors"],
321+
low_noise_weight_name: [""], # Empty or "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors"
322+
adapter_name: ["wan22-distill-lora-i2v"],
323+
scale: [1.0],
321324
from_pt: []
322325
}
323326
# Ex with values:

src/maxdiffusion/generate_wan.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.cloud import storage
2929
import flax
3030
from maxdiffusion.common_types import WAN2_1, WAN2_2
31+
from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader
3132

3233

3334
def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -188,6 +189,42 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
188189
else:
189190
raise ValueError(f"Unsupported model_name for checkpointer: {model_key}")
190191
pipeline, _, _ = checkpoint_loader.load_checkpoint()
192+
193+
# If LoRA is specified, inject layers and load weights.
194+
if config.enable_lora and hasattr(config, "lora_config") and config.lora_config and config.lora_config["lora_model_name_or_path"]:
195+
if model_key == WAN2_1:
196+
lora_loader = Wan2_1NnxLoraLoader()
197+
lora_config = config.lora_config
198+
199+
if len(lora_config["lora_model_name_or_path"]) > 1:
200+
max_logging.log("Found multiple LoRAs in config, but only loading the first one.")
201+
202+
pipeline = lora_loader.load_lora_weights(
203+
pipeline,
204+
lora_config["lora_model_name_or_path"][0],
205+
transformer_weight_name=lora_config["weight_name"][0],
206+
rank=lora_config["rank"][0],
207+
scale=lora_config["scale"][0],
208+
scan_layers=config.scan_layers,
209+
)
210+
211+
if model_key == WAN2_2:
212+
lora_loader = Wan2_2NnxLoraLoader()
213+
lora_config = config.lora_config
214+
215+
if len(lora_config["lora_model_name_or_path"]) > 1:
216+
max_logging.log("Found multiple LoRAs in config, but only loading the first one.")
217+
218+
pipeline = lora_loader.load_lora_weights(
219+
pipeline,
220+
lora_config["lora_model_name_or_path"][0],
221+
high_noise_weight_name=lora_config["high_noise_weight_name"][0],
222+
low_noise_weight_name=lora_config["low_noise_weight_name"][0],
223+
rank=lora_config["rank"][0],
224+
scale=lora_config["scale"][0],
225+
scan_layers=config.scan_layers,
226+
)
227+
191228
s0 = time.perf_counter()
192229

193230
# Using global_batch_size_to_train_on so not to create more config variables

src/maxdiffusion/loaders/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414

1515
from .lora_pipeline import StableDiffusionLoraLoaderMixin
1616
from .flux_lora_pipeline import FluxLoraLoaderMixin
17+
from .wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,3 +608,82 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
608608
raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
609609

610610
return new_state_dict
611+
612+
613+
def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
  """Map a WAN NNX module path onto its Diffusers/LoRA checkpoint key.

  Returns the translated key string, or None when the path has no known
  LoRA counterpart. Verified against wan_utils.py mappings.

  For scanned models the per-layer index is unknown at translation time, so
  the returned key contains a literal "{}" placeholder for the block index.
  """
  # --- 1. Embeddings and head: one-to-one translations outside the blocks. ---
  exact_matches = {
      "condition_embedder.text_embedder.linear_1": "diffusion_model.text_embedding.0",
      "condition_embedder.text_embedder.linear_2": "diffusion_model.text_embedding.2",
      "condition_embedder.time_embedder.linear_1": "diffusion_model.time_embedding.0",
      "condition_embedder.time_embedder.linear_2": "diffusion_model.time_embedding.2",
      "condition_embedder.image_embedder.norm1.layer_norm": "diffusion_model.img_emb.proj.0",
      "condition_embedder.image_embedder.ff.net_0": "diffusion_model.img_emb.proj.1",
      "condition_embedder.image_embedder.ff.net_2": "diffusion_model.img_emb.proj.3",
      "condition_embedder.image_embedder.norm2.layer_norm": "diffusion_model.img_emb.proj.4",
      "patch_embedding": "diffusion_model.patch_embedding",
      "proj_out": "diffusion_model.head.head",
      "condition_embedder.time_proj": "diffusion_model.time_projection.1",
  }
  translated = exact_matches.get(nnx_path_str)
  if translated is not None:
    return translated

  # --- 2. NNX suffix -> LoRA suffix for per-block modules. ---
  suffix_map = {
      # Self attention (attn1)
      "attn1.query": "self_attn.q",
      "attn1.key": "self_attn.k",
      "attn1.value": "self_attn.v",
      "attn1.proj_attn": "self_attn.o",
      # Self attention QK norms
      "attn1.norm_q": "self_attn.norm_q",
      "attn1.norm_k": "self_attn.norm_k",
      # Cross attention (attn2)
      "attn2.query": "cross_attn.q",
      "attn2.key": "cross_attn.k",
      "attn2.value": "cross_attn.v",
      "attn2.proj_attn": "cross_attn.o",
      # Cross attention QK norms
      "attn2.norm_q": "cross_attn.norm_q",
      "attn2.norm_k": "cross_attn.norm_k",
      # Cross attention image branch
      "attn2.add_k_proj": "cross_attn.k_img",
      "attn2.add_v_proj": "cross_attn.v_img",
      "attn2.norm_added_k": "cross_attn.norm_k_img",
      # Feed forward (ffn)
      "ffn.act_fn.proj": "ffn.0",  # Up proj
      "ffn.proj_out": "ffn.2",  # Down proj
      # Global norms & modulation
      "norm2.layer_norm": "norm3",
      "scale_shift_table": "modulation",
      "proj_out": "head.head",
  }

  # --- 3. Per-block translation. ---
  if scan_layers:
    # Scanned: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
    prefix = "blocks."
    if nnx_path_str.startswith(prefix):
      mapped = suffix_map.get(nnx_path_str[len(prefix):])
      if mapped is not None:
        return f"diffusion_model.blocks.{{}}.{mapped}"
    return None

  # Unscanned: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q"
  match = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
  if match:
    mapped = suffix_map.get(match.group(2))
    if mapped is not None:
      return f"diffusion_model.blocks.{match.group(1)}.{mapped}"

  return None
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""NNX-based LoRA loader for WAN models."""
16+
17+
from flax import nnx
18+
from .lora_base import LoRABaseMixin
19+
from .lora_pipeline import StableDiffusionLoraLoaderMixin
20+
from ..models import lora_nnx
21+
from .. import max_logging
22+
from . import lora_conversion_utils
23+
24+
25+
class Wan2_1NnxLoraLoader(LoRABaseMixin):
  """Loads and merges LoRA weights into an NNX-based WAN 2.1 model.

  Assumes the WAN pipeline exposes a 'transformer' attribute that is an
  NNX Module.
  """

  def load_lora_weights(
      self,
      pipeline: nnx.Module,
      lora_model_path: str,
      transformer_weight_name: str,
      rank: int,
      scale: float = 1.0,
      scan_layers: bool = False,
      **kwargs,
  ):
    """Fetch the LoRA state dict and merge it into the pipeline in place.

    Returns the pipeline (merged when a transformer and weight name are
    available, otherwise unchanged).
    """
    # Nothing to merge into — log and hand the pipeline back untouched.
    if not (hasattr(pipeline, "transformer") and transformer_weight_name):
      max_logging.log("transformer not found or no weight name provided for LoRA.")
      return pipeline

    state_dict_loader = StableDiffusionLoraLoaderMixin()

    # Scanned checkpoints stack block weights, so they need the dedicated
    # merge path.
    if scan_layers:
      merge_fn = lora_nnx.merge_lora_for_scanned
    else:
      merge_fn = lora_nnx.merge_lora

    def translate_fn(nnx_path_str):
      return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)

    max_logging.log(f"Merging LoRA into transformer with rank={rank}")
    state_dict, _ = state_dict_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
    merge_fn(pipeline.transformer, state_dict, rank, scale, translate_fn)

    return pipeline
61+
62+
63+
class Wan2_2NnxLoraLoader(LoRABaseMixin):
  """Loads and merges LoRA weights into an NNX-based WAN 2.2 model.

  Assumes the WAN pipeline exposes 'high_noise_transformer' and
  'low_noise_transformer' attributes that are NNX Modules.
  """

  def load_lora_weights(
      self,
      pipeline: nnx.Module,
      lora_model_path: str,
      high_noise_weight_name: str,
      low_noise_weight_name: str,
      rank: int,
      scale: float = 1.0,
      scan_layers: bool = False,
      **kwargs,
  ):
    """Fetch the LoRA state dicts and merge them into both transformers.

    Each transformer is merged independently; a missing attribute or empty
    weight name skips that transformer with a log message. Returns the
    pipeline.
    """
    state_dict_loader = StableDiffusionLoraLoaderMixin()

    # Scanned checkpoints stack block weights, so they need the dedicated
    # merge path.
    if scan_layers:
      merge_fn = lora_nnx.merge_lora_for_scanned
    else:
      merge_fn = lora_nnx.merge_lora

    def translate_fn(nnx_path_str: str):
      return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)

    # Same merge procedure for both transformers — only the target attribute
    # and the checkpoint weight name differ.
    for attr_name, weight_name in (
        ("high_noise_transformer", high_noise_weight_name),
        ("low_noise_transformer", low_noise_weight_name),
    ):
      if hasattr(pipeline, attr_name) and weight_name:
        max_logging.log(f"Merging LoRA into {attr_name} with rank={rank}")
        state_dict, _ = state_dict_loader.lora_state_dict(lora_model_path, weight_name=weight_name, **kwargs)
        merge_fn(getattr(pipeline, attr_name), state_dict, rank, scale, translate_fn)
      else:
        max_logging.log(f"{attr_name} not found or no weight name provided for LoRA.")

    return pipeline

0 commit comments

Comments
 (0)