Commit e1b7221

Add LoRA support for WAN models
1 parent 5a05e75 commit e1b7221

9 files changed: 779 additions & 16 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 6 additions & 4 deletions

@@ -317,12 +317,14 @@ lightning_repo: ""
 lightning_ckpt: ""

 # LoRA parameters
+enable_lora: False
 # Values are lists to support multiple LoRA loading during inference in the future.
 lora_config: {
-  lora_model_name_or_path: [],
-  weight_name: [],
-  adapter_name: [],
-  scale: [],
+  rank: [64],
+  lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"],
+  weight_name: ["wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors"],
+  adapter_name: ["wan21-distill-lora"],
+  scale: [1.0],
   from_pt: []
 }
 # Ex with values:

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 7 additions & 4 deletions

@@ -316,12 +316,15 @@ lightning_repo: ""
 lightning_ckpt: ""

 # LoRA parameters
+enable_lora: False
 # Values are lists to support multiple LoRA loading during inference in the future.
 lora_config: {
-  lora_model_name_or_path: [],
-  weight_name: [],
-  adapter_name: [],
-  scale: [],
+  rank: [64],
+  lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"],
+  high_noise_weight_name: ["wan2.2_t2v_A14b_high_noise_lora_rank64_lightx2v_4step_1217.safetensors"],
+  low_noise_weight_name: ["wan2.2_t2v_A14b_low_noise_lora_rank64_lightx2v_4step_1217.safetensors"],
+  adapter_name: ["wan22-distill-lora"],
+  scale: [1.0],
   from_pt: []
 }
 # Ex with values:

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 6 additions & 4 deletions

@@ -300,12 +300,14 @@ lightning_repo: ""
 lightning_ckpt: ""

 # LoRA parameters
+enable_lora: False
 # Values are lists to support multiple LoRA loading during inference in the future.
 lora_config: {
-  lora_model_name_or_path: [],
-  weight_name: [],
-  adapter_name: [],
-  scale: [],
+  rank: [64],
+  lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"],
+  weight_name: ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors"],
+  adapter_name: ["wan21-distill-lora-i2v"],
+  scale: [1.0],
   from_pt: []
 }
 # Ex with values:

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 7 additions & 4 deletions

@@ -312,12 +312,15 @@ lightning_repo: ""
 lightning_ckpt: ""

 # LoRA parameters
+enable_lora: False
 # Values are lists to support multiple LoRA loading during inference in the future.
 lora_config: {
-  lora_model_name_or_path: [],
-  weight_name: [],
-  adapter_name: [],
-  scale: [],
+  rank: [64],
+  lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"],
+  high_noise_weight_name: ["wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors"],
+  low_noise_weight_name: ["wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors"],
+  adapter_name: ["wan22-distill-lora"],
+  scale: [1.0],
   from_pt: []
 }
 # Ex with values:
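For context on the new rank and scale keys: a LoRA checkpoint stores a low-rank update per targeted weight, and merging adds scale * (B @ A) to the base matrix, where A is (rank, in_features) and B is (out_features, rank). The merge itself lives in lora_nnx, which is not part of this diff; the sketch below is only a minimal illustration of that standard update, not maxdiffusion's actual implementation.

import jax.numpy as jnp

def apply_lora_delta(w, lora_a, lora_b, scale=1.0):
  """Standard LoRA merge: W' = W + scale * (B @ A)."""
  # lora_a: (rank, in_features); lora_b: (out_features, rank)
  return w + scale * (lora_b @ lora_a)

# Toy shapes with rank 4; the adapters shipped in the configs above are rank 64.
w = jnp.zeros((16, 16))
lora_a = 0.01 * jnp.ones((4, 16))
lora_b = 0.01 * jnp.ones((16, 4))
w_merged = apply_lora_delta(w, lora_a, lora_b, scale=1.0)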

src/maxdiffusion/generate_wan.py

Lines changed: 37 additions & 0 deletions

@@ -28,6 +28,7 @@
 from google.cloud import storage
 import flax
 from maxdiffusion.common_types import WAN2_1, WAN2_2
+from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader


 def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -190,6 +191,42 @@ def run(config, pipeline=None, filename_prefix=""):
     else:
       raise ValueError(f"Unsupported model_name for checkpointer: {model_key}")
     pipeline, _, _ = checkpoint_loader.load_checkpoint()
+
+    # If LoRA is specified, inject layers and load weights.
+    if config.enable_lora and hasattr(config, "lora_config") and config.lora_config and config.lora_config["lora_model_name_or_path"]:
+      if model_key == WAN2_1:
+        lora_loader = Wan2_1NnxLoraLoader()
+        lora_config = config.lora_config
+
+        if len(lora_config["lora_model_name_or_path"]) > 1:
+          max_logging.log("Found multiple LoRAs in config, but only loading the first one.")
+
+        pipeline = lora_loader.load_lora_weights(
+            pipeline,
+            lora_config["lora_model_name_or_path"][0],
+            transformer_weight_name=lora_config["weight_name"][0],
+            rank=lora_config["rank"][0],
+            scale=lora_config["scale"][0],
+            scan_layers=config.scan_layers,
+        )
+
+      if model_key == WAN2_2:
+        lora_loader = Wan2_2NnxLoraLoader()
+        lora_config = config.lora_config
+
+        if len(lora_config["lora_model_name_or_path"]) > 1:
+          max_logging.log("Found multiple LoRAs in config, but only loading the first one.")
+
+        pipeline = lora_loader.load_lora_weights(
+            pipeline,
+            lora_config["lora_model_name_or_path"][0],
+            high_noise_weight_name=lora_config["high_noise_weight_name"][0],
+            low_noise_weight_name=lora_config["low_noise_weight_name"][0],
+            rank=lora_config["rank"][0],
+            scale=lora_config["scale"][0],
+            scan_layers=config.scan_layers,
+        )
+
   s0 = time.perf_counter()

   # Using global_batch_size_to_train_on so not to create more config variables

src/maxdiffusion/loaders/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -14,3 +14,4 @@

 from .lora_pipeline import StableDiffusionLoraLoaderMixin
 from .flux_lora_pipeline import FluxLoraLoaderMixin
+from .wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 89 additions & 0 deletions

@@ -608,3 +608,92 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")

   return new_state_dict
+
+
+def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
+  """
+  Translates WAN NNX path to Diffusers/LoRA keys.
+  Verified against wan_utils.py mappings.
+  """
+
+  # --- 1. Embeddings (Exact Matches) ---
+  if nnx_path_str == 'condition_embedder.text_embedder.linear_1':
+    return 'diffusion_model.text_embedding.0'
+  if nnx_path_str == 'condition_embedder.text_embedder.linear_2':
+    return 'diffusion_model.text_embedding.2'
+  if nnx_path_str == 'condition_embedder.time_embedder.linear_1':
+    return 'diffusion_model.time_embedding.0'
+  if nnx_path_str == 'condition_embedder.time_embedder.linear_2':
+    return 'diffusion_model.time_embedding.2'
+  if nnx_path_str == 'condition_embedder.image_embedder.norm1.layer_norm':
+    return 'diffusion_model.img_emb.proj.0'
+  if nnx_path_str == 'condition_embedder.image_embedder.ff.net_0':
+    return 'diffusion_model.img_emb.proj.1'
+  if nnx_path_str == 'condition_embedder.image_embedder.ff.net_2':
+    return 'diffusion_model.img_emb.proj.3'
+  if nnx_path_str == 'condition_embedder.image_embedder.norm2.layer_norm':
+    return 'diffusion_model.img_emb.proj.4'
+  if nnx_path_str == 'patch_embedding':
+    return 'diffusion_model.patch_embedding'
+  if nnx_path_str == 'proj_out':
+    return 'diffusion_model.head.head'
+  if nnx_path_str == 'condition_embedder.time_proj':
+    return 'diffusion_model.time_projection.1'
+
+  # --- 2. Map NNX Suffixes to LoRA Suffixes ---
+  suffix_map = {
+      # Self Attention (attn1)
+      "attn1.query": "self_attn.q",
+      "attn1.key": "self_attn.k",
+      "attn1.value": "self_attn.v",
+      "attn1.proj_attn": "self_attn.o",
+
+      # Self Attention Norms (QK Norm)
+      "attn1.norm_q": "self_attn.norm_q",
+      "attn1.norm_k": "self_attn.norm_k",
+
+      # Cross Attention (attn2)
+      "attn2.query": "cross_attn.q",
+      "attn2.key": "cross_attn.k",
+      "attn2.value": "cross_attn.v",
+      "attn2.proj_attn": "cross_attn.o",
+
+      # Cross Attention Norms (QK Norm)
+      "attn2.norm_q": "cross_attn.norm_q",
+      "attn2.norm_k": "cross_attn.norm_k",
+
+      # Cross Attention img
+      "attn2.add_k_proj": "cross_attn.k_img",
+      "attn2.add_v_proj": "cross_attn.v_img",
+      "attn2.norm_added_k": "cross_attn.norm_k_img",
+
+      # Feed Forward (ffn)
+      "ffn.act_fn.proj": "ffn.0",  # Up proj
+      "ffn.proj_out": "ffn.2",  # Down proj
+
+      # Global Norms & Modulation
+      "norm2.layer_norm": "norm3",
+      "scale_shift_table": "modulation",
+      "proj_out": "head.head",
+  }
+
+  # --- 3. Translation Logic ---
+  if scan_layers:
+    # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
+    if nnx_path_str.startswith("blocks."):
+      inner_suffix = nnx_path_str[len("blocks."):]
+      if inner_suffix in suffix_map:
+        return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}"
+  else:
+    # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q"
+    m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
+    if m:
+      idx, inner_suffix = m.group(1), m.group(2)
+      if inner_suffix in suffix_map:
+        return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
+
+  return None
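A few sample translations to make the mapping concrete; the return values follow directly from the function above (the input paths are illustrative, not an exhaustive list):

translate_wan_nnx_path_to_diffusers_lora("blocks.0.attn1.query")
# -> "diffusion_model.blocks.0.self_attn.q"   (unscanned: block index preserved)

translate_wan_nnx_path_to_diffusers_lora("blocks.attn1.query", scan_layers=True)
# -> "diffusion_model.blocks.{}.self_attn.q"  (scanned: "{}" is filled in per layer by the caller)

translate_wan_nnx_path_to_diffusers_lora("condition_embedder.text_embedder.linear_1")
# -> "diffusion_model.text_embedding.0"       (exact-match embedding path)

translate_wan_nnx_path_to_diffusers_lora("vae.decoder.conv_in")
# -> None                                     (unrecognized paths return None so callers can skip them)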
src/maxdiffusion/loaders/wan_lora_nnx_loader.py

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""NNX-based LoRA loader for WAN models."""
+
+from flax import nnx
+
+from .lora_base import LoRABaseMixin
+from .lora_pipeline import StableDiffusionLoraLoaderMixin
+from ..models import lora_nnx
+from .. import max_logging
+from . import lora_conversion_utils
+
+
+class Wan2_1NnxLoraLoader(LoRABaseMixin):
+  """
+  Handles loading LoRA weights into an NNX-based WAN 2.1 model.
+  Assumes the WAN pipeline contains a 'transformer' attribute
+  that is an NNX Module.
+  """
+
+  def load_lora_weights(
+      self,
+      pipeline: nnx.Module,
+      lora_model_path: str,
+      transformer_weight_name: str,
+      rank: int,
+      scale: float = 1.0,
+      scan_layers: bool = False,
+      **kwargs,
+  ):
+    """Merges LoRA weights into the pipeline from a checkpoint."""
+    lora_loader = StableDiffusionLoraLoaderMixin()
+
+    merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora
+
+    def translate_fn(nnx_path_str):
+      return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
+
+    # Merge into the single WAN 2.1 transformer.
+    if hasattr(pipeline, "transformer") and transformer_weight_name:
+      max_logging.log(f"Merging LoRA into transformer with rank={rank}")
+      state_dict, _ = lora_loader.lora_state_dict(
+          lora_model_path, weight_name=transformer_weight_name, **kwargs
+      )
+      merge_fn(pipeline.transformer, state_dict, rank, scale, translate_fn)
+    else:
+      max_logging.log("transformer not found or no weight name provided for LoRA.")
+
+    return pipeline
+
+
+class Wan2_2NnxLoraLoader(LoRABaseMixin):
+  """
+  Handles loading LoRA weights into an NNX-based WAN 2.2 model.
+  Assumes the WAN pipeline contains 'high_noise_transformer' and 'low_noise_transformer'
+  attributes that are NNX Modules.
+  """
+
+  def load_lora_weights(
+      self,
+      pipeline: nnx.Module,
+      lora_model_path: str,
+      high_noise_weight_name: str,
+      low_noise_weight_name: str,
+      rank: int,
+      scale: float = 1.0,
+      scan_layers: bool = False,
+      **kwargs,
+  ):
+    """Merges LoRA weights into the pipeline from a checkpoint."""
+    lora_loader = StableDiffusionLoraLoaderMixin()
+
+    merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora
+
+    def translate_fn(nnx_path_str: str):
+      return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
+
+    # Handle the high-noise model.
+    if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name:
+      max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}")
+      h_state_dict, _ = lora_loader.lora_state_dict(
+          lora_model_path, weight_name=high_noise_weight_name, **kwargs
+      )
+      merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn)
+    else:
+      max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.")
+
+    # Handle the low-noise model.
+    if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name:
+      max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}")
+      l_state_dict, _ = lora_loader.lora_state_dict(
+          lora_model_path, weight_name=low_noise_weight_name, **kwargs
+      )
+      merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn)
+    else:
+      max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.")
+
+    return pipeline
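A minimal usage sketch, wiring the WAN 2.1 loader to the defaults from base_wan_14b.yml above. It assumes `pipeline` is an already-loaded NNX WAN 2.1 pipeline with a `transformer` attribute (as the class docstring requires), and that scan_layers matches how that transformer was built:

from maxdiffusion.loaders import Wan2_1NnxLoraLoader

loader = Wan2_1NnxLoraLoader()
pipeline = loader.load_lora_weights(
    pipeline,  # e.g. returned by WanCheckpointer.load_checkpoint()
    "lightx2v/Wan2.1-Distill-Loras",
    transformer_weight_name="wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors",
    rank=64,
    scale=1.0,
    scan_layers=True,
)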
