Skip to content

Commit 1536e5c

Browse files
committed
Fix
1 parent ac542bc commit 1536e5c

3 files changed

Lines changed: 4 additions & 12 deletions

File tree

src/maxdiffusion/configs/base_wan_lora_27b.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,9 @@ lightning_ckpt: ""
305305
lora_rank: 64
306306
# Values are lists to support multiple LoRA loading during inference in the future.
307307
lora_config: {
308-
lora_model_name_or_path: ["lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-250928"],
309-
high_noise_weight_name: ["high_noise_model.safetensors"],
310-
low_noise_weight_name: ["low_noise_model.safetensors"],
308+
lora_model_name_or_path: ["lightx2v/Wan2.2-Lightning"],
309+
high_noise_weight_name: ["Wan2.2-T2V-A14B-4steps-lora-250928/high_noise_model.safetensors"],
310+
low_noise_weight_name: ["Wan2.2-T2V-A14B-4steps-lora-250928/low_noise_model.safetensors"],
311311
adapter_name: ["wan22-lightning-lora"],
312312
scale: [1.0],
313313
from_pt: []

src/maxdiffusion/generate_wan.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,6 @@ def run(config, pipeline=None, filename_prefix=""):
166166
low_noise_weight_name=lora_config["low_noise_weight_name"][0],
167167
rank=config.lora_rank,
168168
scale=lora_config["scale"][0],
169-
rng=jax.random.key(config.seed),
170169
)
171170

172171
s0 = time.perf_counter()

src/maxdiffusion/loaders/wan_lora_nnx_loader.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@
1414

1515
"""NNX-based LoRA loader for WAN models."""
1616

17-
import re
1817
from flax import nnx
1918
import jax
20-
import jax.numpy as jnp
2119
from .lora_base import LoRABaseMixin
2220
from .lora_pipeline import StableDiffusionLoraLoaderMixin
2321
from ..models import lora_nnx
@@ -38,16 +36,11 @@ def load_lora_weights(
3836
low_noise_weight_name: str,
3937
rank: int,
4038
scale: float = 1.0,
41-
rng: jax.Array = None,
4239
**kwargs,
4340
):
4441
"""
45-
Injects LoRA layers into the pipeline and loads weights
46-
from a checkpoint.
42+
Merges LoRA weights into the pipeline from a checkpoint.
4743
"""
48-
if rng is None:
49-
rng = jax.random.key(0)
50-
5144
lora_loader = StableDiffusionLoraLoaderMixin()
5245

5346
# Handle high noise model

0 commit comments

Comments
 (0)