Skip to content

Commit a04e1b4

Browse files
committed
Styling changes
1 parent f1f3fc1 commit a04e1b4

2 files changed

Lines changed: 9 additions & 2 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,12 @@ def run(config, pipeline=None, filename_prefix=""):
193193
pipeline, _, _ = checkpoint_loader.load_checkpoint()
194194

195195
# If LoRA is specified, inject layers and load weights.
196-
if config.enable_lora and hasattr(config, "lora_config") and config.lora_config and config.lora_config["lora_model_name_or_path"]:
196+
if (
197+
config.enable_lora
198+
and hasattr(config, "lora_config")
199+
and config.lora_config
200+
and config.lora_config["lora_model_name_or_path"]
201+
):
197202
if model_key == WAN2_1:
198203
lora_loader = Wan2_1NNXLoraLoader()
199204
lora_config = config.lora_config

src/maxdiffusion/models/lora_nnx.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,9 @@ def merge_lora(model: nnx.Module, state_dict: dict, rank: int, scale: float, tra
267267
)
268268

269269

270-
def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, rank: int, scale: float, translate_fn=None, dtype: str = "float32"):
270+
def merge_lora_for_scanned(
271+
model: nnx.Module, state_dict: dict, rank: int, scale: float, translate_fn=None, dtype: str = "float32"
272+
):
271273
"""
272274
Device-Side Optimized Merge for Scanned Layers.
273275
Now supports diff and diff_b.

0 commit comments

Comments
 (0)