Skip to content

Commit 76afafa

Browse files
committed
Refactored based on review
1 parent 95fe0ba commit 76afafa

7 files changed

Lines changed: 329 additions & 345 deletions

File tree

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,8 @@ profiler_steps: 10
276276
enable_jax_named_scopes: False
277277

278278
# Generation parameters
279-
prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
280-
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
279+
prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." #"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
280+
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." #"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
281281
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
282282
do_classifier_free_guidance: True
283283
height: 480
@@ -303,11 +303,11 @@ lightning_ckpt: ""
303303
enable_lora: False
304304
# Values are lists to support multiple LoRA loading during inference in the future.
305305
lora_config: {
306-
rank: [64],
307-
lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"],
308-
weight_name: ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors"],
309-
adapter_name: ["wan21-distill-lora-i2v"],
310-
scale: [1.0],
306+
rank: [64, 32],
307+
lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras", "starsfriday/Wan2.1-Divine-Power-LoRA"],
308+
weight_name: ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors", "divine-power.safetensors"],
309+
adapter_name: ["wan21-distill-lora-i2v", "divine-power-lora"],
310+
scale: [1.0, 1.0],
311311
from_pt: []
312312
}
313313
# Ex with values:

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ profiler_steps: 10
277277
enable_jax_named_scopes: False
278278

279279
# Generation parameters
280-
prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
281-
prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
280+
prompt: "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
281+
prompt_2: "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
282282
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
283283
do_classifier_free_guidance: True
284284
height: 480
@@ -288,10 +288,10 @@ flow_shift: 3.0
288288

289289
# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py
290290
# guidance scale factor for low noise transformer
291-
guidance_scale_low: 3.0
291+
guidance_scale_low: 3.0
292292

293293
# guidance scale factor for high noise transformer
294-
guidance_scale_high: 4.0
294+
guidance_scale_high: 4.0
295295

296296
# The timestep threshold. If `t` is at or above this value,
297297
# the `high_noise_model` is considered as the required model.
@@ -315,12 +315,12 @@ lightning_ckpt: ""
315315
enable_lora: False
316316
# Values are lists to support multiple LoRA loading during inference in the future.
317317
lora_config: {
318-
rank: [64],
319-
lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"],
320-
high_noise_weight_name: ["wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors"],
321-
low_noise_weight_name: [""], # Empty or "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors"
322-
adapter_name: ["wan22-lightning-lora"],
323-
scale: [1.0],
318+
rank: [64, 16],
319+
lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras", "ostris/wan22_i2v_14b_orbit_shot_lora"],
320+
high_noise_weight_name: ["wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors", "wan22_14b_i2v_orbit_high_noise.safetensors"],
321+
low_noise_weight_name: ["wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors", "wan22_14b_i2v_orbit_low_noise.safetensors"], # Empty or "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors"
322+
adapter_name: ["wan22-distill-lora", "wan22-orbit-lora"],
323+
scale: [1.0, 1.0],
324324
from_pt: []
325325
}
326326
# Ex with values:

src/maxdiffusion/generate_wan.py

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from google.cloud import storage
2929
import flax
3030
from maxdiffusion.common_types import WAN2_1, WAN2_2
31-
from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader
31+
from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader
3232

3333

3434
def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -191,39 +191,40 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
191191
pipeline, _, _ = checkpoint_loader.load_checkpoint()
192192

193193
# If LoRA is specified, inject layers and load weights.
194-
if config.enable_lora and hasattr(config, "lora_config") and config.lora_config and config.lora_config["lora_model_name_or_path"]:
194+
if (
195+
config.enable_lora
196+
and hasattr(config, "lora_config")
197+
and config.lora_config
198+
and config.lora_config["lora_model_name_or_path"]
199+
):
195200
if model_key == WAN2_1:
196-
lora_loader = Wan2_1NnxLoraLoader()
201+
lora_loader = Wan2_1NNXLoraLoader()
197202
lora_config = config.lora_config
198-
199-
if len(lora_config["lora_model_name_or_path"]) > 1:
200-
max_logging.log("Found multiple LoRAs in config, but only loading the first one.")
201-
202-
pipeline = lora_loader.load_lora_weights(
203-
pipeline,
204-
lora_config["lora_model_name_or_path"][0],
205-
transformer_weight_name=lora_config["weight_name"][0],
206-
rank=lora_config["rank"][0],
207-
scale=lora_config["scale"][0],
208-
scan_layers=config.scan_layers,
209-
)
203+
for i in range(len(lora_config["lora_model_name_or_path"])):
204+
pipeline = lora_loader.load_lora_weights(
205+
pipeline,
206+
lora_config["lora_model_name_or_path"][i],
207+
transformer_weight_name=lora_config["weight_name"][i],
208+
rank=lora_config["rank"][i],
209+
scale=lora_config["scale"][i],
210+
scan_layers=config.scan_layers,
211+
dtype=config.weights_dtype,
212+
)
210213

211214
if model_key == WAN2_2:
212-
lora_loader = Wan2_2NnxLoraLoader()
215+
lora_loader = Wan2_2NNXLoraLoader()
213216
lora_config = config.lora_config
214-
215-
if len(lora_config["lora_model_name_or_path"]) > 1:
216-
max_logging.log("Found multiple LoRAs in config, but only loading the first one.")
217-
218-
pipeline = lora_loader.load_lora_weights(
219-
pipeline,
220-
lora_config["lora_model_name_or_path"][0],
221-
high_noise_weight_name=lora_config["high_noise_weight_name"][0],
222-
low_noise_weight_name=lora_config["low_noise_weight_name"][0],
223-
rank=lora_config["rank"][0],
224-
scale=lora_config["scale"][0],
225-
scan_layers=config.scan_layers,
226-
)
217+
for i in range(len(lora_config["lora_model_name_or_path"])):
218+
pipeline = lora_loader.load_lora_weights(
219+
pipeline,
220+
lora_config["lora_model_name_or_path"][i],
221+
high_noise_weight_name=lora_config["high_noise_weight_name"][i],
222+
low_noise_weight_name=lora_config["low_noise_weight_name"][i],
223+
rank=lora_config["rank"][i],
224+
scale=lora_config["scale"][i],
225+
scan_layers=config.scan_layers,
226+
dtype=config.weights_dtype,
227+
)
227228

228229
s0 = time.perf_counter()
229230

src/maxdiffusion/loaders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414

1515
from .lora_pipeline import StableDiffusionLoraLoaderMixin
1616
from .flux_lora_pipeline import FluxLoraLoaderMixin
17-
from .wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader
17+
from .wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,20 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
610610
return new_state_dict
611611

612612

613+
def preprocess_wan_lora_dict(state_dict):
614+
"""
615+
Preprocesses WAN LoRA dict to convert diff_m to modulation.diff.
616+
"""
617+
new_d = {}
618+
for k, v in state_dict.items():
619+
if k.endswith(".diff_m"):
620+
new_k = k.removesuffix(".diff_m") + ".modulation.diff"
621+
new_d[new_k] = v
622+
else:
623+
new_d[k] = v
624+
return new_d
625+
626+
613627
def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
614628
"""
615629
Translates WAN NNX path to Diffusers/LoRA keys.
@@ -637,6 +651,8 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
637651
return "diffusion_model.patch_embedding"
638652
if nnx_path_str == "proj_out":
639653
return "diffusion_model.head.head"
654+
if nnx_path_str == "scale_shift_table":
655+
return "diffusion_model.head.modulation"
640656
if nnx_path_str == "condition_embedder.time_proj":
641657
return "diffusion_model.time_projection.1"
642658

@@ -667,7 +683,7 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
667683
"ffn.proj_out": "ffn.2", # Down proj
668684
# Global Norms & Modulation
669685
"norm2.layer_norm": "norm3",
670-
"scale_shift_table": "modulation",
686+
"adaln_scale_shift_table": "modulation",
671687
"proj_out": "head.head",
672688
}
673689

src/maxdiffusion/loaders/wan_lora_nnx_loader.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from . import lora_conversion_utils
2323

2424

25-
class Wan2_1NnxLoraLoader(LoRABaseMixin):
25+
class Wan2_1NNXLoraLoader(LoRABaseMixin):
2626
"""
2727
Handles loading LoRA weights into NNX-based WAN 2.1 model.
2828
Assumes WAN pipeline contains 'transformer'
@@ -37,6 +37,7 @@ def load_lora_weights(
3737
rank: int,
3838
scale: float = 1.0,
3939
scan_layers: bool = False,
40+
dtype: str = "float32",
4041
**kwargs,
4142
):
4243
"""
@@ -49,18 +50,18 @@ def load_lora_weights(
4950
def translate_fn(nnx_path_str):
5051
return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
5152

52-
# Handle high noise model
5353
if hasattr(pipeline, "transformer") and transformer_weight_name:
5454
max_logging.log(f"Merging LoRA into transformer with rank={rank}")
5555
h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
56-
merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn)
56+
h_state_dict = lora_conversion_utils.preprocess_wan_lora_dict(h_state_dict)
57+
merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
5758
else:
5859
max_logging.log("transformer not found or no weight name provided for LoRA.")
5960

6061
return pipeline
6162

6263

63-
class Wan2_2NnxLoraLoader(LoRABaseMixin):
64+
class Wan2_2NNXLoraLoader(LoRABaseMixin):
6465
"""
6566
Handles loading LoRA weights into NNX-based WAN 2.2 model.
6667
Assumes WAN pipeline contains 'high_noise_transformer' and 'low_noise_transformer'
@@ -76,6 +77,7 @@ def load_lora_weights(
7677
rank: int,
7778
scale: float = 1.0,
7879
scan_layers: bool = False,
80+
dtype: str = "float32",
7981
**kwargs,
8082
):
8183
"""
@@ -92,15 +94,17 @@ def translate_fn(nnx_path_str: str):
9294
if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name:
9395
max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}")
9496
h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=high_noise_weight_name, **kwargs)
95-
merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn)
97+
h_state_dict = lora_conversion_utils.preprocess_wan_lora_dict(h_state_dict)
98+
merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
9699
else:
97100
max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.")
98101

99102
# Handle low noise model
100103
if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name:
101104
max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}")
102105
l_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=low_noise_weight_name, **kwargs)
103-
merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn)
106+
l_state_dict = lora_conversion_utils.preprocess_wan_lora_dict(l_state_dict)
107+
merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn, dtype=dtype)
104108
else:
105109
max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.")
106110

0 commit comments

Comments (0)