@@ -212,18 +212,33 @@ def load_base_wan_transformer(
212212 device = jax .local_devices (backend = device )[0 ]
213213 filename = "diffusion_pytorch_model.safetensors.index.json"
214214 local_files = False
215+
216+ # Only rank 0 downloads; others wait for cache to be populated
217+ process_index = jax .process_index ()
215218 if os .path .isdir (pretrained_model_name_or_path ):
216219 index_file_path = os .path .join (pretrained_model_name_or_path , subfolder , filename )
217220 if not os .path .isfile (index_file_path ):
218221 raise FileNotFoundError (f"File { index_file_path } not found for local directory." )
219222 local_files = True
220223 elif hf_download :
221- # download the index file for sharded models.
222- index_file_path = hf_hub_download (
223- pretrained_model_name_or_path ,
224- subfolder = subfolder ,
225- filename = filename ,
226- )
224+ # Only rank 0 downloads; synchronize across all ranks
225+ if process_index == 0 :
226+ # download the index file for sharded models.
227+ index_file_path = hf_hub_download (
228+ pretrained_model_name_or_path ,
229+ subfolder = subfolder ,
230+ filename = filename ,
231+ )
232+ jax .experimental .multihost_utils .sync_global_devices ("model_index_download" )
233+
234+ if process_index != 0 :
235+ # Non-rank-0 processes wait and use the cached path
236+ index_file_path = hf_hub_download (
237+ pretrained_model_name_or_path ,
238+ subfolder = subfolder ,
239+ filename = filename ,
240+ force_download = False , # Use cache, don't download
241+ )
227242 with jax .default_device (device ):
228243 # open the index file.
229244 with open (index_file_path , "r" ) as f :
@@ -238,7 +253,19 @@ def load_base_wan_transformer(
238253 if local_files :
239254 ckpt_shard_path = os .path .join (pretrained_model_name_or_path , subfolder , model_file )
240255 else :
241- ckpt_shard_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = model_file )
256+ # Only rank 0 downloads new files; others use cached versions
257+ if process_index == 0 :
258+ ckpt_shard_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = model_file )
259+ jax .experimental .multihost_utils .sync_global_devices (f"model_download_{ model_file } " )
260+
261+ if process_index != 0 :
262+ # Non-rank-0: use cached version
263+ ckpt_shard_path = hf_hub_download (
264+ pretrained_model_name_or_path ,
265+ subfolder = subfolder ,
266+ filename = model_file ,
267+ force_download = False , # Use cache
268+ )
242269 # now get all the filenames for the model that need downloading
243270 max_logging .log (f"Load and port { pretrained_model_name_or_path } { subfolder } on { device } " )
244271
@@ -304,12 +331,25 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device:
304331 device = jax .devices (device )[0 ]
305332 subfolder = "vae"
306333 filename = "diffusion_pytorch_model.safetensors"
334+ process_index = jax .process_index ()
335+
307336 if os .path .isdir (pretrained_model_name_or_path ):
308337 ckpt_path = os .path .join (pretrained_model_name_or_path , subfolder , filename )
309338 if not os .path .isfile (ckpt_path ):
310339 raise FileNotFoundError (f"File { ckpt_path } not found for local directory." )
311340 elif hf_download :
312- ckpt_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = filename )
341+ # Only rank 0 downloads; others use cache
342+ if process_index == 0 :
343+ ckpt_path = hf_hub_download (pretrained_model_name_or_path , subfolder = subfolder , filename = filename )
344+ jax .experimental .multihost_utils .sync_global_devices ("vae_download" )
345+
346+ if process_index != 0 :
347+ ckpt_path = hf_hub_download (
348+ pretrained_model_name_or_path ,
349+ subfolder = subfolder ,
350+ filename = filename ,
351+ force_download = False , # Use cache
352+ )
313353 max_logging .log (f"Load and port { pretrained_model_name_or_path } VAE on { device } " )
314354 with jax .default_device (device ):
315355 if ckpt_path is not None :
0 commit comments