|
28 | 28 | from google.cloud import storage |
29 | 29 | import flax |
30 | 30 | from maxdiffusion.common_types import WAN2_1, WAN2_2 |
31 | | -from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader |
| 31 | +from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader |
32 | 32 |
|
33 | 33 |
|
34 | 34 | def upload_video_to_gcs(output_dir: str, video_path: str): |
@@ -195,37 +195,33 @@ def run(config, pipeline=None, filename_prefix=""): |
195 | 195 | # If LoRA is specified, inject layers and load weights. |
196 | 196 | if config.enable_lora and hasattr(config, "lora_config") and config.lora_config and config.lora_config["lora_model_name_or_path"]: |
197 | 197 | if model_key == WAN2_1: |
198 | | - lora_loader = Wan2_1NnxLoraLoader() |
| 198 | + lora_loader = Wan2_1NNXLoraLoader() |
199 | 199 | lora_config = config.lora_config |
200 | | - |
201 | | - if len(lora_config["lora_model_name_or_path"]) > 1: |
202 | | - max_logging.log("Found multiple LoRAs in config, but only loading the first one.") |
203 | | - |
204 | | - pipeline = lora_loader.load_lora_weights( |
205 | | - pipeline, |
206 | | - lora_config["lora_model_name_or_path"][0], |
207 | | - transformer_weight_name=lora_config["weight_name"][0], |
208 | | - rank=lora_config["rank"][0], |
209 | | - scale=lora_config["scale"][0], |
210 | | - scan_layers=config.scan_layers, |
211 | | - ) |
| 200 | + for i in range(len(lora_config["lora_model_name_or_path"])): |
| 201 | + pipeline = lora_loader.load_lora_weights( |
| 202 | + pipeline, |
| 203 | + lora_config["lora_model_name_or_path"][i], |
| 204 | + transformer_weight_name=lora_config["weight_name"][i], |
| 205 | + rank=lora_config["rank"][i], |
| 206 | + scale=lora_config["scale"][i], |
| 207 | + scan_layers=config.scan_layers, |
| 208 | + dtype=config.weights_dtype, |
| 209 | + ) |
212 | 210 |
|
213 | 211 | if model_key == WAN2_2: |
214 | | - lora_loader = Wan2_2NnxLoraLoader() |
| 212 | + lora_loader = Wan2_2NNXLoraLoader() |
215 | 213 | lora_config = config.lora_config |
216 | | - |
217 | | - if len(lora_config["lora_model_name_or_path"]) > 1: |
218 | | - max_logging.log("Found multiple LoRAs in config, but only loading the first one.") |
219 | | - |
220 | | - pipeline = lora_loader.load_lora_weights( |
221 | | - pipeline, |
222 | | - lora_config["lora_model_name_or_path"][0], |
223 | | - high_noise_weight_name=lora_config["high_noise_weight_name"][0], |
224 | | - low_noise_weight_name=lora_config["low_noise_weight_name"][0], |
225 | | - rank=lora_config["rank"][0], |
226 | | - scale=lora_config["scale"][0], |
227 | | - scan_layers=config.scan_layers, |
228 | | - ) |
| 214 | + for i in range(len(lora_config["lora_model_name_or_path"])): |
| 215 | + pipeline = lora_loader.load_lora_weights( |
| 216 | + pipeline, |
| 217 | + lora_config["lora_model_name_or_path"][i], |
| 218 | + high_noise_weight_name=lora_config["high_noise_weight_name"][i], |
| 219 | + low_noise_weight_name=lora_config["low_noise_weight_name"][i], |
| 220 | + rank=lora_config["rank"][i], |
| 221 | + scale=lora_config["scale"][i], |
| 222 | + scan_layers=config.scan_layers, |
| 223 | + dtype=config.weights_dtype, |
| 224 | + ) |
229 | 225 |
|
230 | 226 | s0 = time.perf_counter() |
231 | 227 |
|
|
0 commit comments