Skip to content

Commit 5fc7b3d

Browse files
committed
Errors catching corrections
1 parent 629f515 commit 5fc7b3d

2 files changed

Lines changed: 6 additions & 4 deletions

File tree

src/maxdiffusion/generate_ltx2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from maxdiffusion import pyconfig, max_logging, max_utils
2323
from absl import app
2424
from google.cloud import storage
25+
from google.api_core.exceptions import GoogleAPIError
2526
import flax
2627
from maxdiffusion.utils.export_utils import export_to_video_with_audio
2728

@@ -48,8 +49,8 @@ def upload_video_to_gcs(output_dir: str, video_path: str):
4849
blob.upload_from_filename(source_file_path)
4950
max_logging.log(f"Upload complete {source_file_path}.")
5051

51-
except Exception as e:
52-
max_logging.log(f"An error occurred: {e}")
52+
except GoogleAPIError as e:
53+
max_logging.log(f"A storage error occurred during upload: {e}")
5354

5455

5556
def delete_file(file_path: str):

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import jax.numpy as jnp
55
from maxdiffusion import max_logging
66
from huggingface_hub import hf_hub_download
7+
from huggingface_hub.utils import EntryNotFoundError
78
from safetensors import safe_open
89
from flax.traverse_util import unflatten_dict, flatten_dict
910
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
@@ -101,12 +102,12 @@ def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device):
101102
with safe_open(shard_path, framework="pt") as f:
102103
for k in f.keys():
103104
tensors[k] = torch2jax(f.get_tensor(k))
104-
except Exception:
105+
except EntryNotFoundError:
105106
# Fallback to single file
106107
filename = "diffusion_pytorch_model.safetensors"
107108
try:
108109
ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
109-
except Exception:
110+
except EntryNotFoundError:
110111
filename = "diffusion_pytorch_model.bin"
111112
ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
112113

0 commit comments

Comments
 (0)