@@ -212,33 +212,18 @@ def load_base_wan_transformer(
212212 device = jax .local_devices (backend = device )[0 ]
213213 filename = "diffusion_pytorch_model.safetensors.index.json"
214214 local_files = False
215-
216- # Only rank 0 downloads; others wait for cache to be populated
217- process_index = jax .process_index ()
218215 if os .path .isdir (pretrained_model_name_or_path ):
219216 index_file_path = os .path .join (pretrained_model_name_or_path , subfolder , filename )
220217 if not os .path .isfile (index_file_path ):
221218 raise FileNotFoundError (f"File { index_file_path } not found for local directory." )
222219 local_files = True
223220 elif hf_download :
224- # Only rank 0 downloads; synchronize across all ranks
225- if process_index == 0 :
226- # download the index file for sharded models.
227- index_file_path = hf_hub_download (
228- pretrained_model_name_or_path ,
229- subfolder = subfolder ,
230- filename = filename ,
231- )
232- jax .experimental .multihost_utils .sync_global_devices ("model_index_download" )
233-
234- if process_index != 0 :
235- # Non-rank-0 processes wait and use the cached path
236- index_file_path = hf_hub_download (
237- pretrained_model_name_or_path ,
238- subfolder = subfolder ,
239- filename = filename ,
240- force_download = False , # Use cache, don't download
241- )
221+ # download the index file for sharded models.
222+ index_file_path = hf_hub_download (
223+ pretrained_model_name_or_path ,
224+ subfolder = subfolder ,
225+ filename = filename ,
226+ )
242227 with jax .default_device (device ):
243228 # open the index file.
244229 with open (index_file_path , "r" ) as f :
@@ -253,19 +238,7 @@ def load_base_wan_transformer(
253238 if local_files :
254239 ckpt_shard_path = os .path .join (pretrained_model_name_or_path , subfolder , model_file )
255240 else :
256- # Only rank 0 downloads new files; others use cached versions
257- if process_index == 0 :
258- ckpt_shard_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = model_file )
259- jax .experimental .multihost_utils .sync_global_devices (f"model_download_{ model_file } " )
260-
261- if process_index != 0 :
262- # Non-rank-0: use cached version
263- ckpt_shard_path = hf_hub_download (
264- pretrained_model_name_or_path ,
265- subfolder = subfolder ,
266- filename = model_file ,
267- force_download = False , # Use cache
268- )
241+ ckpt_shard_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = model_file )
269242 # now get all the filenames for the model that need downloading
270243 max_logging .log (f"Load and port { pretrained_model_name_or_path } { subfolder } on { device } " )
271244
@@ -331,25 +304,12 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device:
331304 device = jax .devices (device )[0 ]
332305 subfolder = "vae"
333306 filename = "diffusion_pytorch_model.safetensors"
334- process_index = jax .process_index ()
335-
336307 if os .path .isdir (pretrained_model_name_or_path ):
337308 ckpt_path = os .path .join (pretrained_model_name_or_path , subfolder , filename )
338309 if not os .path .isfile (ckpt_path ):
339310 raise FileNotFoundError (f"File { ckpt_path } not found for local directory." )
340311 elif hf_download :
341- # Only rank 0 downloads; others use cache
342- if process_index == 0 :
343- ckpt_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = filename )
344- jax .experimental .multihost_utils .sync_global_devices ("vae_download" )
345-
346- if process_index != 0 :
347- ckpt_path = hf_hub_download (
348- pretrained_model_name_or_path ,
349- subfolder = subfolder ,
350- filename = filename ,
351- force_download = False , # Use cache
352- )
312+ ckpt_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = filename )
353313 max_logging .log (f"Load and port { pretrained_model_name_or_path } VAE on { device } " )
354314 with jax .default_device (device ):
355315 if ckpt_path is not None :
0 commit comments