@@ -196,14 +196,15 @@ def user_init(raw_keys):
196196
197197 # Orbax doesn't save the tokenizer params, instead it loads them from the pretrained_model_name_or_path
198198 raw_keys ["tokenizer_model_name_or_path" ] = raw_keys ["pretrained_model_name_or_path" ]
199+ tmp_dir = raw_keys .get ("tmp_dir" , "/tmp" )
199200 if "gs://" in raw_keys ["tokenizer_model_name_or_path" ]:
200- raw_keys ["pretrained_model_name_or_path" ] = max_utils .download_blobs (raw_keys ["pretrained_model_name_or_path" ], "/tmp" )
201+ raw_keys ["pretrained_model_name_or_path" ] = max_utils .download_blobs (raw_keys ["pretrained_model_name_or_path" ], tmp_dir )
201202 if "gs://" in raw_keys ["pretrained_model_name_or_path" ]:
202- raw_keys ["pretrained_model_name_or_path" ] = max_utils .download_blobs (raw_keys ["pretrained_model_name_or_path" ], "/tmp" )
203+ raw_keys ["pretrained_model_name_or_path" ] = max_utils .download_blobs (raw_keys ["pretrained_model_name_or_path" ], tmp_dir )
203204 if "gs://" in raw_keys ["unet_checkpoint" ]:
204- raw_keys ["unet_checkpoint" ] = max_utils .download_blobs (raw_keys ["unet_checkpoint" ], "/tmp" )
205+ raw_keys ["unet_checkpoint" ] = max_utils .download_blobs (raw_keys ["unet_checkpoint" ], tmp_dir )
205206 if "gs://" in raw_keys ["tokenizer_model_name_or_path" ]:
206- raw_keys ["tokenizer_model_name_or_path" ] = max_utils .download_blobs (raw_keys ["tokenizer_model_name_or_path" ], "/tmp" )
207+ raw_keys ["tokenizer_model_name_or_path" ] = max_utils .download_blobs (raw_keys ["tokenizer_model_name_or_path" ], tmp_dir )
207208 if "gs://" in raw_keys ["dataset_name" ]:
208209 raw_keys ["dataset_name" ] = max_utils .download_blobs (raw_keys ["dataset_name" ], raw_keys ["dataset_save_location" ])
209210 raw_keys ["dataset_save_location" ] = raw_keys ["dataset_name" ]
0 commit comments