@@ -150,7 +150,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
150150 elif hf_download :
151151 # download the index file for sharded models.
152152 index_file_path = hf_hub_download (
153- pretrained_model_name_or_path , subfolder , filename ,
153+ pretrained_model_name_or_path , subfolder = subfolder , filename = filename ,
154154 )
155155 with jax .default_device (device ):
156156 # open the index file.
@@ -166,7 +166,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
166166 if local_files :
167167 ckpt_shard_path = os .path .join (pretrained_model_name_or_path , subfolder , model_file )
168168 else :
169- ckpt_shard_path = hf_hub_download (pretrained_model_name_or_path , subfolder = "transformer" , filename = model_file )
169+ ckpt_shard_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = model_file )
170170 # now get all the filenames for the model that need downloading
171171 max_logging .log (f"Load and port Wan 2.1 transformer on { device } " )
172172
@@ -214,7 +214,7 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device:
214214 raise FileNotFoundError (f"File { ckpt_path } not found for local directory." )
215215 elif hf_download :
216216 ckpt_path = hf_hub_download (
217- pretrained_model_name_or_path , subfolder , filename
217+ pretrained_model_name_or_path , subfolder = subfolder , filename = filename
218218 )
219219 max_logging .log (f"Load and port Wan 2.1 VAE on { device } " )
220220 with jax .default_device (device ):
0 commit comments