|
| 1 | +""" |
| 2 | +Copyright 2026 Google LLC |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +""" |
| 16 | + |
1 | 17 | import json |
2 | 18 | import torch |
3 | 19 | import jax |
4 | 20 | import jax.numpy as jnp |
5 | 21 | from maxdiffusion import max_logging |
6 | 22 | from huggingface_hub import hf_hub_download |
| 23 | +from huggingface_hub.utils import EntryNotFoundError |
7 | 24 | from safetensors import safe_open |
8 | 25 | from flax.traverse_util import unflatten_dict, flatten_dict |
9 | 26 | from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict) |
@@ -101,12 +118,12 @@ def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device): |
101 | 118 | with safe_open(shard_path, framework="pt") as f: |
102 | 119 | for k in f.keys(): |
103 | 120 | tensors[k] = torch2jax(f.get_tensor(k)) |
104 | | - except Exception: |
| 121 | + except EntryNotFoundError: |
105 | 122 | # Fallback to single file |
106 | 123 | filename = "diffusion_pytorch_model.safetensors" |
107 | 124 | try: |
108 | 125 | ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename) |
109 | | - except Exception: |
| 126 | + except EntryNotFoundError: |
110 | 127 | filename = "diffusion_pytorch_model.bin" |
111 | 128 | ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename) |
112 | 129 |
|
|
0 commit comments