
Commit 655143d

Refactored based on review

1 parent 6b2b704

4 files changed: 253 additions & 323 deletions


src/maxdiffusion/generate_wan.py
Lines changed: 24 additions & 28 deletions

@@ -28,7 +28,7 @@
 from google.cloud import storage
 import flax
 from maxdiffusion.common_types import WAN2_1, WAN2_2
-from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader
+from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader


 def upload_video_to_gcs(output_dir: str, video_path: str):

@@ -195,37 +195,33 @@ def run(config, pipeline=None, filename_prefix=""):
   # If LoRA is specified, inject layers and load weights.
   if config.enable_lora and hasattr(config, "lora_config") and config.lora_config and config.lora_config["lora_model_name_or_path"]:
     if model_key == WAN2_1:
-      lora_loader = Wan2_1NnxLoraLoader()
+      lora_loader = Wan2_1NNXLoraLoader()
       lora_config = config.lora_config
-
-      if len(lora_config["lora_model_name_or_path"]) > 1:
-        max_logging.log("Found multiple LoRAs in config, but only loading the first one.")
-
-      pipeline = lora_loader.load_lora_weights(
-          pipeline,
-          lora_config["lora_model_name_or_path"][0],
-          transformer_weight_name=lora_config["weight_name"][0],
-          rank=lora_config["rank"][0],
-          scale=lora_config["scale"][0],
-          scan_layers=config.scan_layers,
-      )
+      for i in range(len(lora_config["lora_model_name_or_path"])):
+        pipeline = lora_loader.load_lora_weights(
+            pipeline,
+            lora_config["lora_model_name_or_path"][i],
+            transformer_weight_name=lora_config["weight_name"][i],
+            rank=lora_config["rank"][i],
+            scale=lora_config["scale"][i],
+            scan_layers=config.scan_layers,
+            dtype=config.weights_dtype,
+        )

     if model_key == WAN2_2:
-      lora_loader = Wan2_2NnxLoraLoader()
+      lora_loader = Wan2_2NNXLoraLoader()
       lora_config = config.lora_config
-
-      if len(lora_config["lora_model_name_or_path"]) > 1:
-        max_logging.log("Found multiple LoRAs in config, but only loading the first one.")
-
-      pipeline = lora_loader.load_lora_weights(
-          pipeline,
-          lora_config["lora_model_name_or_path"][0],
-          high_noise_weight_name=lora_config["high_noise_weight_name"][0],
-          low_noise_weight_name=lora_config["low_noise_weight_name"][0],
-          rank=lora_config["rank"][0],
-          scale=lora_config["scale"][0],
-          scan_layers=config.scan_layers,
-      )
+      for i in range(len(lora_config["lora_model_name_or_path"])):
+        pipeline = lora_loader.load_lora_weights(
+            pipeline,
+            lora_config["lora_model_name_or_path"][i],
+            high_noise_weight_name=lora_config["high_noise_weight_name"][i],
+            low_noise_weight_name=lora_config["low_noise_weight_name"][i],
+            rank=lora_config["rank"][i],
+            scale=lora_config["scale"][i],
+            scan_layers=config.scan_layers,
+            dtype=config.weights_dtype,
+        )

   s0 = time.perf_counter()
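The net change here: instead of warning and loading only the first LoRA, the loop now folds every configured adapter into the pipeline, passing config.weights_dtype through to the merge. The loop assumes the lora_config lists are parallel, i.e. index i of each list describes the same adapter. A minimal sketch of what a two-adapter WAN 2.1 lora_config could look like (key names are taken from the diff above; the paths, weight names, ranks, and scales are hypothetical):

lora_config = {
  # Parallel lists: entry i of each list describes adapter i.
  "lora_model_name_or_path": ["path/or/repo/for/lora_a", "path/or/repo/for/lora_b"],  # hypothetical
  "weight_name": ["lora_a.safetensors", "lora_b.safetensors"],  # hypothetical
  "rank": [64, 32],
  "scale": [1.0, 0.7],
}

Since load_lora_weights returns the updated pipeline, the iterations compose: each pass merges one more adapter into the transformer weights.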

src/maxdiffusion/loaders/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -14,4 +14,4 @@

 from .lora_pipeline import StableDiffusionLoraLoaderMixin
 from .flux_lora_pipeline import FluxLoraLoaderMixin
-from .wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader
+from .wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader
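With the re-export updated, the renamed classes stay importable either from the package or from the module directly, e.g.:

# Both import paths resolve to the same classes after the rename.
from maxdiffusion.loaders import Wan2_2NNXLoraLoader
from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NNXLoraLoader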

src/maxdiffusion/loaders/wan_lora_nnx_loader.py
Lines changed: 7 additions & 5 deletions

@@ -22,7 +22,7 @@
 from . import lora_conversion_utils


-class Wan2_1NnxLoraLoader(LoRABaseMixin):
+class Wan2_1NNXLoraLoader(LoRABaseMixin):
   """
   Handles loading LoRA weights into NNX-based WAN 2.1 model.
   Assumes WAN pipeline contains 'transformer'

@@ -37,6 +37,7 @@ def load_lora_weights(
       rank: int,
       scale: float = 1.0,
       scan_layers: bool = False,
+      dtype: str = "float32",
       **kwargs,
   ):
     """

@@ -53,14 +54,14 @@ def translate_fn(nnx_path_str):
     if hasattr(pipeline, "transformer") and transformer_weight_name:
       max_logging.log(f"Merging LoRA into transformer with rank={rank}")
       h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
-      merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn)
+      merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
     else:
       max_logging.log("transformer not found or no weight name provided for LoRA.")

     return pipeline


-class Wan2_2NnxLoraLoader(LoRABaseMixin):
+class Wan2_2NNXLoraLoader(LoRABaseMixin):
   """
   Handles loading LoRA weights into NNX-based WAN 2.2 model.
   Assumes WAN pipeline contains 'high_noise_transformer' and 'low_noise_transformer'

@@ -76,6 +77,7 @@ def load_lora_weights(
       rank: int,
       scale: float = 1.0,
       scan_layers: bool = False,
+      dtype: str = "float32",
       **kwargs,
   ):
     """

@@ -92,15 +94,15 @@ def translate_fn(nnx_path_str: str):
     if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name:
       max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}")
       h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=high_noise_weight_name, **kwargs)
-      merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn)
+      merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
     else:
       max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.")

     # Handle low noise model
     if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name:
       max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}")
       l_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=low_noise_weight_name, **kwargs)
-      merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn)
+      merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn, dtype=dtype)
     else:
       max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.")
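For context on the new dtype parameter: merge_fn is defined elsewhere in this module, so only its call sites appear in the diff. Conceptually it performs the standard LoRA merge, W' = W + scale * (B @ A), for each targeted weight. A minimal, hypothetical sketch of that per-weight merge (the float32 accumulation and the cast placement are assumptions, and the real merge_fn additionally walks the state dict via translate_fn, which this omits):

import jax.numpy as jnp

def merge_one_lora_weight(w, lora_a, lora_b, scale, dtype="float32"):
  # Standard LoRA merge: W' = W + scale * (B @ A), where lora_a is
  # (rank, in_features) and lora_b is (out_features, rank).
  # Accumulate in float32 for numerical headroom, then cast the merged
  # weight to the requested storage dtype (e.g. "bfloat16").
  delta = scale * (lora_b.astype(jnp.float32) @ lora_a.astype(jnp.float32))
  return (w.astype(jnp.float32) + delta).astype(dtype)

Threading dtype through load_lora_weights (defaulting to "float32", with generate_wan.py supplying config.weights_dtype) presumably keeps the merged weights in the same precision as the rest of the model rather than silently upcasting them.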
