Skip to content

Commit 2a48490

Browse files
read local wan checkpoints.
1 parent b219048 commit 2a48490

1 file changed

Lines changed: 30 additions & 14 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import json
23
import torch
34
import jax
@@ -136,15 +137,22 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict,
136137
else:
137138
return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download)
138139

139-
140140
def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
141141
device = jax.devices(device)[0]
142-
with jax.default_device(device):
143-
if hf_download:
144-
# download the index file for sharded models.
145-
index_file_path = hf_hub_download(
146-
pretrained_model_name_or_path, subfolder="transformer", filename="diffusion_pytorch_model.safetensors.index.json"
147-
)
142+
subfolder="transformer"
143+
filename="diffusion_pytorch_model.safetensors.index.json"
144+
local_files = False
145+
if os.path.isdir(pretrained_model_name_or_path):
146+
index_file_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
147+
if not os.path.isfile(index_file_path):
148+
raise FileNotFoundError(f"File {index_file_path} not found for local directory.")
149+
local_files = True
150+
elif hf_download:
151+
# download the index file for sharded models.
152+
index_file_path = hf_hub_download(
153+
pretrained_model_name_or_path, subfolder, filename,
154+
)
155+
with jax.default_device(device):
148156
# open the index file.
149157
with open(index_file_path, "r") as f:
150158
index_dict = json.load(f)
@@ -155,7 +163,10 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
155163
model_files = list(model_files)
156164
tensors = {}
157165
for model_file in model_files:
158-
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename=model_file)
166+
if local_files:
167+
ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
168+
else:
169+
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename=model_file)
159170
# now get all the filenames for the model that need downloading
160171
max_logging.log(f"Load and port Wan 2.1 transformer on {device}")
161172

@@ -195,13 +206,18 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
195206

196207
def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
197208
device = jax.devices(device)[0]
209+
subfolder="vae"
210+
filename="diffusion_pytorch_model.safetensors"
211+
if os.path.isdir(pretrained_model_name_or_path):
212+
ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
213+
if not os.path.isfile(ckpt_path):
214+
raise FileNotFoundError(f"File {ckpt_path} not found for local directory.")
215+
elif hf_download:
216+
ckpt_path = hf_hub_download(
217+
pretrained_model_name_or_path, subfolder, filename
218+
)
219+
max_logging.log(f"Load and port Wan 2.1 VAE on {device}")
198220
with jax.default_device(device):
199-
if hf_download:
200-
ckpt_path = hf_hub_download(
201-
pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors"
202-
)
203-
max_logging.log(f"Load and port Wan 2.1 VAE on {device}")
204-
205221
if ckpt_path is not None:
206222
tensors = {}
207223
with safe_open(ckpt_path, framework="pt") as f:

0 commit comments

Comments
 (0)