Skip to content

Commit 223ad70

Browse files
add posoitional arg names to hf_hub_download
1 parent 4d1775f commit 223ad70

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
150150
elif hf_download:
151151
# download the index file for sharded models.
152152
index_file_path = hf_hub_download(
153-
pretrained_model_name_or_path, subfolder, filename,
153+
pretrained_model_name_or_path, subfolder=subfolder, filename=filename,
154154
)
155155
with jax.default_device(device):
156156
# open the index file.
@@ -166,7 +166,7 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
166166
if local_files:
167167
ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
168168
else:
169-
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename=model_file)
169+
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
170170
# now get all the filenames for the model that need downloading
171171
max_logging.log(f"Load and port Wan 2.1 transformer on {device}")
172172

@@ -214,7 +214,7 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device:
214214
raise FileNotFoundError(f"File {ckpt_path} not found for local directory.")
215215
elif hf_download:
216216
ckpt_path = hf_hub_download(
217-
pretrained_model_name_or_path, subfolder, filename
217+
pretrained_model_name_or_path, subfolder=subfolder, filename=filename
218218
)
219219
max_logging.log(f"Load and port Wan 2.1 VAE on {device}")
220220
with jax.default_device(device):

0 commit comments

Comments
 (0)