Skip to content

Commit b13347c

Browse files
committed
Add LoRA support for WAN models
1 parent ad56886 commit b13347c

9 files changed

Lines changed: 782 additions & 22 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -317,12 +317,14 @@ lightning_repo: ""
317317
lightning_ckpt: ""
318318

319319
# LoRA parameters
320+
enable_lora: False
320321
# Values are lists to support multiple LoRA loading during inference in the future.
321322
lora_config: {
322-
lora_model_name_or_path: [],
323-
weight_name: [],
324-
adapter_name: [],
325-
scale: [],
323+
rank: [64],
324+
lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"],
325+
weight_name: ["wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors"],
326+
adapter_name: ["wan21-distill-lora"],
327+
scale: [1.0],
326328
from_pt: []
327329
}
328330
# Ex with values:

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,12 +316,15 @@ lightning_repo: ""
316316
lightning_ckpt: ""
317317

318318
# LoRA parameters
319+
enable_lora: False
319320
# Values are lists to support multiple LoRA loading during inference in the future.
320321
lora_config: {
321-
lora_model_name_or_path: [],
322-
weight_name: [],
323-
adapter_name: [],
324-
scale: [],
322+
rank: [64],
323+
lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"],
324+
high_noise_weight_name: ["wan2.2_t2v_A14b_high_noise_lora_rank64_lightx2v_4step_1217.safetensors"],
325+
low_noise_weight_name: ["wan2.2_t2v_A14b_low_noise_lora_rank64_lightx2v_4step_1217.safetensors"],
326+
adapter_name: ["wan22-distill-lora"],
327+
scale: [1.0],
325328
from_pt: []
326329
}
327330
# Ex with values:

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,8 @@ profiler_steps: 10
276276
enable_jax_named_scopes: False
277277

278278
# Generation parameters
279-
prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
280-
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
279+
prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." #"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
280+
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." #"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
281281
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
282282
do_classifier_free_guidance: True
283283
height: 480
@@ -300,12 +300,14 @@ lightning_repo: ""
300300
lightning_ckpt: ""
301301

302302
# LoRA parameters
303+
enable_lora: False
303304
# Values are lists to support multiple LoRA loading during inference in the future.
304305
lora_config: {
305-
lora_model_name_or_path: [],
306-
weight_name: [],
307-
adapter_name: [],
308-
scale: [],
306+
rank: [64, 32],
307+
lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras", "starsfriday/Wan2.1-Divine-Power-LoRA"],
308+
weight_name: ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors", "divine-power.safetensors"],
309+
adapter_name: ["wan21-distill-lora-i2v", "divine-power-lora"],
310+
scale: [1.0, 1.0],
309311
from_pt: []
310312
}
311313
# Ex with values:

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ profiler_steps: 10
277277
enable_jax_named_scopes: False
278278

279279
# Generation parameters
280-
prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
281-
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
280+
prompt: "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
281+
prompt_2: "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
282282
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
283283
do_classifier_free_guidance: True
284284
height: 480
@@ -288,10 +288,10 @@ flow_shift: 3.0
288288

289289
# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py
290290
# guidance scale factor for low noise transformer
291-
guidance_scale_low: 3.0
291+
guidance_scale_low: 3.0
292292

293293
# guidance scale factor for high noise transformer
294-
guidance_scale_high: 4.0
294+
guidance_scale_high: 4.0
295295

296296
# The timestep threshold. If `t` is at or above this value,
297297
# the `high_noise_model` is considered as the required model.
@@ -312,12 +312,15 @@ lightning_repo: ""
312312
lightning_ckpt: ""
313313

314314
# LoRA parameters
315+
enable_lora: False
315316
# Values are lists to support multiple LoRA loading during inference in the future.
316317
lora_config: {
317-
lora_model_name_or_path: [],
318-
weight_name: [],
319-
adapter_name: [],
320-
scale: [],
318+
rank: [64, 16],
319+
lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras", "ostris/wan22_i2v_14b_orbit_shot_lora"],
320+
high_noise_weight_name: ["wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors", "wan22_14b_i2v_orbit_high_noise.safetensors"],
321+
low_noise_weight_name: ["wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors", "wan22_14b_i2v_orbit_low_noise.safetensors"], # Empty or "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors"
322+
adapter_name: ["wan22-distill-lora", "wan22-orbit-lora"],
323+
scale: [1.0, 1.0],
321324
from_pt: []
322325
}
323326
# Ex with values:

src/maxdiffusion/generate_wan.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.cloud import storage
2929
import flax
3030
from maxdiffusion.common_types import WAN2_1, WAN2_2
31+
from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader
3132

3233

3334
def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -188,6 +189,43 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
188189
else:
189190
raise ValueError(f"Unsupported model_name for checkpointer: {model_key}")
190191
pipeline, _, _ = checkpoint_loader.load_checkpoint()
192+
193+
# If LoRA is specified, inject layers and load weights.
194+
if (
195+
config.enable_lora
196+
and hasattr(config, "lora_config")
197+
and config.lora_config
198+
and config.lora_config["lora_model_name_or_path"]
199+
):
200+
if model_key == WAN2_1:
201+
lora_loader = Wan2_1NNXLoraLoader()
202+
lora_config = config.lora_config
203+
for i in range(len(lora_config["lora_model_name_or_path"])):
204+
pipeline = lora_loader.load_lora_weights(
205+
pipeline,
206+
lora_config["lora_model_name_or_path"][i],
207+
transformer_weight_name=lora_config["weight_name"][i],
208+
rank=lora_config["rank"][i],
209+
scale=lora_config["scale"][i],
210+
scan_layers=config.scan_layers,
211+
dtype=config.weights_dtype,
212+
)
213+
214+
if model_key == WAN2_2:
215+
lora_loader = Wan2_2NNXLoraLoader()
216+
lora_config = config.lora_config
217+
for i in range(len(lora_config["lora_model_name_or_path"])):
218+
pipeline = lora_loader.load_lora_weights(
219+
pipeline,
220+
lora_config["lora_model_name_or_path"][i],
221+
high_noise_weight_name=lora_config["high_noise_weight_name"][i],
222+
low_noise_weight_name=lora_config["low_noise_weight_name"][i],
223+
rank=lora_config["rank"][i],
224+
scale=lora_config["scale"][i],
225+
scan_layers=config.scan_layers,
226+
dtype=config.weights_dtype,
227+
)
228+
191229
s0 = time.perf_counter()
192230

193231
# Using global_batch_size_to_train_on so not to create more config variables

src/maxdiffusion/loaders/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414

1515
from .lora_pipeline import StableDiffusionLoraLoaderMixin
1616
from .flux_lora_pipeline import FluxLoraLoaderMixin
17+
from .wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,3 +608,98 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
608608
raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
609609

610610
return new_state_dict
611+
612+
613+
def preprocess_wan_lora_dict(state_dict):
  """Preprocesses a WAN LoRA state dict, renaming ``diff_m`` keys to ``modulation.diff``."""

  def _rename(key):
    # Some WAN LoRA checkpoints store modulation deltas under a ".diff_m"
    # suffix; rewrite those to the ".modulation.diff" naming the merge code
    # expects. All other keys pass through unchanged.
    if key.endswith(".diff_m"):
      return key.removesuffix(".diff_m") + ".modulation.diff"
    return key

  return {_rename(k): v for k, v in state_dict.items()}
625+
626+
627+
def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
  """Translates a WAN NNX module path to its Diffusers/LoRA key.

  Verified against wan_utils.py mappings.

  Args:
    nnx_path_str: Dotted NNX module path, e.g. "blocks.0.attn1.query".
    scan_layers: If True, block paths carry no per-layer index
      ("blocks.attn1.query") and the returned key contains a "{}"
      placeholder where the layer index goes.

  Returns:
    The matching Diffusers/LoRA key string, or None when the path has no
    known mapping.
  """

  # --- 1. Embeddings and other top-level modules (exact matches) ---
  # A single dict lookup replaces the original chain of equality checks.
  exact_map = {
      "condition_embedder.text_embedder.linear_1": "diffusion_model.text_embedding.0",
      "condition_embedder.text_embedder.linear_2": "diffusion_model.text_embedding.2",
      "condition_embedder.time_embedder.linear_1": "diffusion_model.time_embedding.0",
      "condition_embedder.time_embedder.linear_2": "diffusion_model.time_embedding.2",
      "condition_embedder.image_embedder.norm1.layer_norm": "diffusion_model.img_emb.proj.0",
      "condition_embedder.image_embedder.ff.net_0": "diffusion_model.img_emb.proj.1",
      "condition_embedder.image_embedder.ff.net_2": "diffusion_model.img_emb.proj.3",
      "condition_embedder.image_embedder.norm2.layer_norm": "diffusion_model.img_emb.proj.4",
      "patch_embedding": "diffusion_model.patch_embedding",
      "proj_out": "diffusion_model.head.head",
      "scale_shift_table": "diffusion_model.head.modulation",
      "condition_embedder.time_proj": "diffusion_model.time_projection.1",
  }
  if nnx_path_str in exact_map:
    return exact_map[nnx_path_str]

  # --- 2. Map NNX suffixes (inside a transformer block) to LoRA suffixes ---
  suffix_map = {
      # Self Attention (attn1)
      "attn1.query": "self_attn.q",
      "attn1.key": "self_attn.k",
      "attn1.value": "self_attn.v",
      "attn1.proj_attn": "self_attn.o",
      # Self Attention Norms (QK Norm)
      "attn1.norm_q": "self_attn.norm_q",
      "attn1.norm_k": "self_attn.norm_k",
      # Cross Attention (attn2)
      "attn2.query": "cross_attn.q",
      "attn2.key": "cross_attn.k",
      "attn2.value": "cross_attn.v",
      "attn2.proj_attn": "cross_attn.o",
      # Cross Attention Norms (QK Norm)
      "attn2.norm_q": "cross_attn.norm_q",
      "attn2.norm_k": "cross_attn.norm_k",
      # Cross Attention img
      "attn2.add_k_proj": "cross_attn.k_img",
      "attn2.add_v_proj": "cross_attn.v_img",
      "attn2.norm_added_k": "cross_attn.norm_k_img",
      # Feed Forward (ffn)
      "ffn.act_fn.proj": "ffn.0",  # Up proj
      "ffn.proj_out": "ffn.2",  # Down proj
      # Global Norms & Modulation
      "norm2.layer_norm": "norm3",
      "adaln_scale_shift_table": "modulation",
      "proj_out": "head.head",
  }

  # --- 3. Translation Logic ---
  if scan_layers:
    # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
    if nnx_path_str.startswith("blocks."):
      inner_suffix = nnx_path_str[len("blocks.") :]
      if inner_suffix in suffix_map:
        return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}"
  else:
    # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q"
    m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
    if m:
      idx, inner_suffix = m.group(1), m.group(2)
      if inner_suffix in suffix_map:
        return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"

  return None
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""NNX-based LoRA loader for WAN models."""
16+
17+
from flax import nnx
18+
from .lora_base import LoRABaseMixin
19+
from .lora_pipeline import StableDiffusionLoraLoaderMixin
20+
from ..models import lora_nnx
21+
from .. import max_logging
22+
from . import lora_conversion_utils
23+
24+
25+
class Wan2_1NNXLoraLoader(LoRABaseMixin):
  """Loads LoRA weights into an NNX-based WAN 2.1 model.

  Assumes the WAN pipeline exposes a 'transformer' attribute that is an
  NNX Module.
  """

  def load_lora_weights(
      self,
      pipeline: nnx.Module,
      lora_model_path: str,
      transformer_weight_name: str,
      rank: int,
      scale: float = 1.0,
      scan_layers: bool = False,
      dtype: str = "float32",
      **kwargs,
  ):
    """Merges LoRA weights from a checkpoint into the pipeline's transformer."""
    # Guard clause: nothing to do without a target transformer and a file name.
    if not (hasattr(pipeline, "transformer") and transformer_weight_name):
      max_logging.log("transformer not found or no weight name provided for LoRA.")
      return pipeline

    # Pick the merge routine matching the transformer's layout.
    merge = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora

    def translate(path):
      # Maps an NNX module path onto the LoRA checkpoint's key naming.
      return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(path, scan_layers=scan_layers)

    max_logging.log(f"Merging LoRA into transformer with rank={rank}")
    state_dict, _ = StableDiffusionLoraLoaderMixin().lora_state_dict(
        lora_model_path, weight_name=transformer_weight_name, **kwargs
    )
    state_dict = lora_conversion_utils.preprocess_wan_lora_dict(state_dict)
    merge(pipeline.transformer, state_dict, rank, scale, translate, dtype=dtype)

    return pipeline
62+
63+
64+
class Wan2_2NNXLoraLoader(LoRABaseMixin):
  """Loads LoRA weights into an NNX-based WAN 2.2 model.

  Assumes the WAN pipeline exposes 'high_noise_transformer' and
  'low_noise_transformer' attributes that are NNX Modules.
  """

  def load_lora_weights(
      self,
      pipeline: nnx.Module,
      lora_model_path: str,
      high_noise_weight_name: str,
      low_noise_weight_name: str,
      rank: int,
      scale: float = 1.0,
      scan_layers: bool = False,
      dtype: str = "float32",
      **kwargs,
  ):
    """Merges high- and low-noise LoRA checkpoints into the pipeline."""
    checkpoint_loader = StableDiffusionLoraLoaderMixin()

    # Pick the merge routine matching the transformers' layout.
    merge = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora

    def translate(path: str):
      # Maps an NNX module path onto the LoRA checkpoint's key naming.
      return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(path, scan_layers=scan_layers)

    def _merge_into(attr_name, weight_name):
      # Shared path for the high- and low-noise transformers: skip with a log
      # line when the pipeline lacks the attribute or no file was configured.
      if not (hasattr(pipeline, attr_name) and weight_name):
        max_logging.log(f"{attr_name} not found or no weight name provided for LoRA.")
        return
      max_logging.log(f"Merging LoRA into {attr_name} with rank={rank}")
      state_dict, _ = checkpoint_loader.lora_state_dict(lora_model_path, weight_name=weight_name, **kwargs)
      state_dict = lora_conversion_utils.preprocess_wan_lora_dict(state_dict)
      merge(getattr(pipeline, attr_name), state_dict, rank, scale, translate, dtype=dtype)

    _merge_into("high_noise_transformer", high_noise_weight_name)
    _merge_into("low_noise_transformer", low_noise_weight_name)

    return pipeline

0 commit comments

Comments
 (0)