Commit 7f018e4

Refactored based on review

1 parent 95fe0ba commit 7f018e4

5 files changed

Lines changed: 276 additions & 335 deletions

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 7 additions & 7 deletions
@@ -276,8 +276,8 @@ profiler_steps: 10
 enable_jax_named_scopes: False
 
 # Generation parameters
-prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
-prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." #"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." #"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
 negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
 do_classifier_free_guidance: True
 height: 480
@@ -303,11 +303,11 @@ lightning_ckpt: ""
 enable_lora: False
 # Values are lists to support multiple LoRA loading during inference in the future.
 lora_config: {
-  rank: [64],
-  lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"],
-  weight_name: ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors"],
-  adapter_name: ["wan21-distill-lora-i2v"],
-  scale: [1.0],
+  rank: [64, 32],
+  lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras", "starsfriday/Wan2.1-Divine-Power-LoRA"],
+  weight_name: ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors", "divine-power.safetensors"],
+  adapter_name: ["wan21-distill-lora-i2v", "divine-power-lora"],
+  scale: [1.0, 1.0],
 from_pt: []
 }
 # Ex with values:
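
The lora_config lists are positional: entry i of rank, lora_model_name_or_path, weight_name, adapter_name, and scale together describe the i-th LoRA. A minimal sketch of iterating such parallel lists (the length check and zip are illustrative, not code from this commit; from_pt is omitted here because it is empty):

```python
lora_config = {
    "rank": [64, 32],
    "lora_model_name_or_path": ["lightx2v/Wan2.1-Distill-Loras", "starsfriday/Wan2.1-Divine-Power-LoRA"],
    "weight_name": ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors", "divine-power.safetensors"],
    "adapter_name": ["wan21-distill-lora-i2v", "divine-power-lora"],
    "scale": [1.0, 1.0],
}

# Every list must have the same length, or the positional pairing silently breaks.
lengths = {key: len(values) for key, values in lora_config.items()}
assert len(set(lengths.values())) == 1, f"lora_config lists disagree in length: {lengths}"

# Entry i of each list describes LoRA i.
for rank, path, weight, name, scale in zip(*lora_config.values()):
    print(f"{name}: {path}/{weight} (rank={rank}, scale={scale})")
```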

src/maxdiffusion/generate_wan.py

Lines changed: 30 additions & 29 deletions
@@ -28,7 +28,7 @@
 from google.cloud import storage
 import flax
 from maxdiffusion.common_types import WAN2_1, WAN2_2
-from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader
+from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader
 
 
 def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -191,39 +191,40 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
   pipeline, _, _ = checkpoint_loader.load_checkpoint()
 
   # If LoRA is specified, inject layers and load weights.
-  if config.enable_lora and hasattr(config, "lora_config") and config.lora_config and config.lora_config["lora_model_name_or_path"]:
+  if (
+      config.enable_lora
+      and hasattr(config, "lora_config")
+      and config.lora_config
+      and config.lora_config["lora_model_name_or_path"]
+  ):
     if model_key == WAN2_1:
-      lora_loader = Wan2_1NnxLoraLoader()
+      lora_loader = Wan2_1NNXLoraLoader()
       lora_config = config.lora_config
-
-      if len(lora_config["lora_model_name_or_path"]) > 1:
-        max_logging.log("Found multiple LoRAs in config, but only loading the first one.")
-
-      pipeline = lora_loader.load_lora_weights(
-          pipeline,
-          lora_config["lora_model_name_or_path"][0],
-          transformer_weight_name=lora_config["weight_name"][0],
-          rank=lora_config["rank"][0],
-          scale=lora_config["scale"][0],
-          scan_layers=config.scan_layers,
-      )
+      for i in range(len(lora_config["lora_model_name_or_path"])):
+        pipeline = lora_loader.load_lora_weights(
+            pipeline,
+            lora_config["lora_model_name_or_path"][i],
+            transformer_weight_name=lora_config["weight_name"][i],
+            rank=lora_config["rank"][i],
+            scale=lora_config["scale"][i],
+            scan_layers=config.scan_layers,
+            dtype=config.weights_dtype,
+        )
 
     if model_key == WAN2_2:
-      lora_loader = Wan2_2NnxLoraLoader()
+      lora_loader = Wan2_2NNXLoraLoader()
       lora_config = config.lora_config
-
-      if len(lora_config["lora_model_name_or_path"]) > 1:
-        max_logging.log("Found multiple LoRAs in config, but only loading the first one.")
-
-      pipeline = lora_loader.load_lora_weights(
-          pipeline,
-          lora_config["lora_model_name_or_path"][0],
-          high_noise_weight_name=lora_config["high_noise_weight_name"][0],
-          low_noise_weight_name=lora_config["low_noise_weight_name"][0],
-          rank=lora_config["rank"][0],
-          scale=lora_config["scale"][0],
-          scan_layers=config.scan_layers,
-      )
+      for i in range(len(lora_config["lora_model_name_or_path"])):
+        pipeline = lora_loader.load_lora_weights(
+            pipeline,
+            lora_config["lora_model_name_or_path"][i],
+            high_noise_weight_name=lora_config["high_noise_weight_name"][i],
+            low_noise_weight_name=lora_config["low_noise_weight_name"][i],
+            rank=lora_config["rank"][i],
+            scale=lora_config["scale"][i],
+            scan_layers=config.scan_layers,
+            dtype=config.weights_dtype,
+        )
 
   s0 = time.perf_counter()
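
Because each load_lora_weights call returns the pipeline with that adapter already merged, the reassignment makes the adapters accumulate: the second LoRA is merged into weights that already contain the first. A zip over the parallel config lists expresses the same WAN 2.1 loop without manual indexing (an illustrative sketch, not code from this commit):

```python
for path, weight, rank, scale in zip(
    lora_config["lora_model_name_or_path"],
    lora_config["weight_name"],
    lora_config["rank"],
    lora_config["scale"],
):
    # Each iteration merges one adapter; the updated pipeline feeds the next.
    pipeline = lora_loader.load_lora_weights(
        pipeline,
        path,
        transformer_weight_name=weight,
        rank=rank,
        scale=scale,
        scan_layers=config.scan_layers,
        dtype=config.weights_dtype,
    )
```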

src/maxdiffusion/loaders/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,4 +14,4 @@
 
 from .lora_pipeline import StableDiffusionLoraLoaderMixin
 from .flux_lora_pipeline import FluxLoraLoaderMixin
-from .wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader
+from .wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader
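
The rename capitalizes the NNX acronym in both class names. Downstream call sites that import from the package root pick up the new names (a minimal sketch):

```python
from maxdiffusion.loaders import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader

# Renamed from Wan2_1NnxLoraLoader / Wan2_2NnxLoraLoader in this commit.
loader = Wan2_1NNXLoraLoader()
```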

src/maxdiffusion/loaders/wan_lora_nnx_loader.py

Lines changed: 7 additions & 5 deletions
@@ -22,7 +22,7 @@
 from . import lora_conversion_utils
 
 
-class Wan2_1NnxLoraLoader(LoRABaseMixin):
+class Wan2_1NNXLoraLoader(LoRABaseMixin):
   """
   Handles loading LoRA weights into NNX-based WAN 2.1 model.
   Assumes WAN pipeline contains 'transformer'
@@ -37,6 +37,7 @@ def load_lora_weights(
       rank: int,
       scale: float = 1.0,
       scan_layers: bool = False,
+      dtype: str = "float32",
       **kwargs,
   ):
     """
@@ -53,14 +54,14 @@ def translate_fn(nnx_path_str):
     if hasattr(pipeline, "transformer") and transformer_weight_name:
       max_logging.log(f"Merging LoRA into transformer with rank={rank}")
       h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
-      merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn)
+      merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
     else:
       max_logging.log("transformer not found or no weight name provided for LoRA.")
 
     return pipeline
 
 
-class Wan2_2NnxLoraLoader(LoRABaseMixin):
+class Wan2_2NNXLoraLoader(LoRABaseMixin):
   """
   Handles loading LoRA weights into NNX-based WAN 2.2 model.
   Assumes WAN pipeline contains 'high_noise_transformer' and 'low_noise_transformer'
@@ -76,6 +77,7 @@ def load_lora_weights(
       rank: int,
       scale: float = 1.0,
       scan_layers: bool = False,
+      dtype: str = "float32",
       **kwargs,
   ):
     """
@@ -92,15 +94,15 @@ def translate_fn(nnx_path_str: str):
     if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name:
       max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}")
       h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=high_noise_weight_name, **kwargs)
-      merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn)
+      merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
     else:
       max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.")
 
     # Handle low noise model
     if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name:
       max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}")
       l_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=low_noise_weight_name, **kwargs)
-      merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn)
+      merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn, dtype=dtype)
     else:
       max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.")
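
The new dtype parameter threads config.weights_dtype from generate_wan.py down to merge_fn, so merged parameters keep the transformer's storage precision instead of defaulting to float32. A hedged sketch of the arithmetic a LoRA merge typically performs (this repo's actual merge_fn may differ in scaling details; the function name here is illustrative):

```python
import jax.numpy as jnp

def merge_lora_delta(weight, lora_a, lora_b, scale, dtype="bfloat16"):
    # lora_b: (out, r), lora_a: (r, in), so lora_b @ lora_a matches weight's shape.
    # Accumulate in float32 for accuracy, then cast to the requested dtype so the
    # merged parameter keeps the model's storage precision.
    delta = scale * (lora_b.astype(jnp.float32) @ lora_a.astype(jnp.float32))
    return (weight.astype(jnp.float32) + delta).astype(dtype)
```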
