|
28 | 28 | from google.cloud import storage |
29 | 29 | import flax |
30 | 30 | from maxdiffusion.common_types import WAN2_1, WAN2_2 |
31 | | -from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader |
| 31 | +from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader |
32 | 32 |
|
33 | 33 |
|
34 | 34 | def upload_video_to_gcs(output_dir: str, video_path: str): |
@@ -191,39 +191,40 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): |
191 | 191 | pipeline, _, _ = checkpoint_loader.load_checkpoint() |
192 | 192 |
|
193 | 193 | # If LoRA is specified, inject layers and load weights. |
194 | | - if config.enable_lora and hasattr(config, "lora_config") and config.lora_config and config.lora_config["lora_model_name_or_path"]: |
| 194 | + if ( |
| 195 | + config.enable_lora |
| 196 | + and hasattr(config, "lora_config") |
| 197 | + and config.lora_config |
| 198 | + and config.lora_config["lora_model_name_or_path"] |
| 199 | + ): |
195 | 200 | if model_key == WAN2_1: |
196 | | - lora_loader = Wan2_1NnxLoraLoader() |
| 201 | + lora_loader = Wan2_1NNXLoraLoader() |
197 | 202 | lora_config = config.lora_config |
198 | | - |
199 | | - if len(lora_config["lora_model_name_or_path"]) > 1: |
200 | | - max_logging.log("Found multiple LoRAs in config, but only loading the first one.") |
201 | | - |
202 | | - pipeline = lora_loader.load_lora_weights( |
203 | | - pipeline, |
204 | | - lora_config["lora_model_name_or_path"][0], |
205 | | - transformer_weight_name=lora_config["weight_name"][0], |
206 | | - rank=lora_config["rank"][0], |
207 | | - scale=lora_config["scale"][0], |
208 | | - scan_layers=config.scan_layers, |
209 | | - ) |
| 203 | + for i in range(len(lora_config["lora_model_name_or_path"])): |
| 204 | + pipeline = lora_loader.load_lora_weights( |
| 205 | + pipeline, |
| 206 | + lora_config["lora_model_name_or_path"][i], |
| 207 | + transformer_weight_name=lora_config["weight_name"][i], |
| 208 | + rank=lora_config["rank"][i], |
| 209 | + scale=lora_config["scale"][i], |
| 210 | + scan_layers=config.scan_layers, |
| 211 | + dtype=config.weights_dtype, |
| 212 | + ) |
210 | 213 |
|
211 | 214 | if model_key == WAN2_2: |
212 | | - lora_loader = Wan2_2NnxLoraLoader() |
| 215 | + lora_loader = Wan2_2NNXLoraLoader() |
213 | 216 | lora_config = config.lora_config |
214 | | - |
215 | | - if len(lora_config["lora_model_name_or_path"]) > 1: |
216 | | - max_logging.log("Found multiple LoRAs in config, but only loading the first one.") |
217 | | - |
218 | | - pipeline = lora_loader.load_lora_weights( |
219 | | - pipeline, |
220 | | - lora_config["lora_model_name_or_path"][0], |
221 | | - high_noise_weight_name=lora_config["high_noise_weight_name"][0], |
222 | | - low_noise_weight_name=lora_config["low_noise_weight_name"][0], |
223 | | - rank=lora_config["rank"][0], |
224 | | - scale=lora_config["scale"][0], |
225 | | - scan_layers=config.scan_layers, |
226 | | - ) |
| 217 | + for i in range(len(lora_config["lora_model_name_or_path"])): |
| 218 | + pipeline = lora_loader.load_lora_weights( |
| 219 | + pipeline, |
| 220 | + lora_config["lora_model_name_or_path"][i], |
| 221 | + high_noise_weight_name=lora_config["high_noise_weight_name"][i], |
| 222 | + low_noise_weight_name=lora_config["low_noise_weight_name"][i], |
| 223 | + rank=lora_config["rank"][i], |
| 224 | + scale=lora_config["scale"][i], |
| 225 | + scan_layers=config.scan_layers, |
| 226 | + dtype=config.weights_dtype, |
| 227 | + ) |
227 | 228 |
|
228 | 229 | s0 = time.perf_counter() |
229 | 230 |
|
|
0 commit comments