Skip to content

Commit 841167d

Browse files
committed
fix
1 parent 7f75eeb commit 841167d

1 file changed

Lines changed: 75 additions & 33 deletions

File tree

inspect_vae_checkpoint.py

Lines changed: 75 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,84 @@
11

2-
import os
3-
from huggingface_hub import snapshot_download
42
from safetensors import safe_open
5-
import torch
3+
from huggingface_hub import snapshot_download
4+
import os
65

7-
def inspect_checkpoint():
6+
def inspect_structure():
7+
resume_from_checkpoint = "Lightricks/LTX-Video"
8+
cache_dir = os.path.join(os.path.expanduser("~"), ".cache/huggingface/hub")
9+
10+
vae_path = None
11+
search_path = os.path.join(cache_dir, "models--Lightricks--LTX-Video/snapshots")
12+
if os.path.exists(search_path):
13+
for root, dirs, files in os.walk(search_path):
14+
if "vae" in root and "diffusion_pytorch_model.safetensors" in files:
15+
vae_path = os.path.join(root, "diffusion_pytorch_model.safetensors")
16+
break
17+
18+
if not vae_path:
19+
print("VAE checkpoint not found.")
20+
return
21+
22+
print(f"Analyzing checkpoint: {vae_path}")
23+
24+
structure = {
25+
"encoder": {"down_blocks": {}, "mid_block": 0},
26+
"decoder": {"up_blocks": {}, "mid_block": 0}
27+
}
28+
829
try:
9-
# Allow looking in the user's cache
10-
cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
11-
print(f"Scanning cache dir: {cache_dir}")
12-
13-
# We know the model is Lightricks/LTX-Video
14-
# We look for snapshots
15-
repo_id = "Lightricks/LTX-Video"
16-
17-
# Try to find it
18-
# We can use hf_hub_download to find the path without downloading if it exists
19-
from huggingface_hub import hf_hub_download
20-
21-
try:
22-
vae_path = hf_hub_download(repo_id, subfolder="vae", filename="diffusion_pytorch_model.safetensors")
23-
print(f"Found VAE checkpoint at: {vae_path}")
30+
with safe_open(vae_path, framework="pt", device="cpu") as f:
31+
keys = f.keys()
32+
print(f"Total keys: {len(keys)}")
2433

25-
with safe_open(vae_path, framework="pt") as f:
26-
keys = f.keys()
27-
print(f"Total keys in VAE checkpoint: {len(keys)}")
28-
print("Sample keys:")
29-
for i, k in enumerate(keys):
30-
if i < 20:
31-
print(k)
32-
if "resnets" in k and "up_blocks" in k and i % 10 == 0:
33-
print(f"Resnet key sample: {k}")
34+
for key in keys:
35+
parts = key.split(".")
36+
if "resnets" not in parts:
37+
continue
38+
39+
try:
40+
resnets_idx = parts.index("resnets")
41+
# The next part should be the index
42+
if len(parts) > resnets_idx + 1 and parts[resnets_idx + 1].isdigit():
43+
block_idx = int(parts[resnets_idx + 1])
3444

35-
except Exception as e:
36-
print(f"Could not find VAE checkpoint via hf_hub_download: {e}")
37-
45+
if parts[0] == "encoder":
46+
if "down_blocks" in parts:
47+
down_idx_loc = parts.index("down_blocks") + 1
48+
down_idx = int(parts[down_idx_loc])
49+
if down_idx not in structure["encoder"]["down_blocks"]:
50+
structure["encoder"]["down_blocks"][down_idx] = 0
51+
structure["encoder"]["down_blocks"][down_idx] = max(structure["encoder"]["down_blocks"][down_idx], block_idx + 1)
52+
elif "mid_block" in parts:
53+
structure["encoder"]["mid_block"] = max(structure["encoder"]["mid_block"], block_idx + 1)
54+
55+
elif parts[0] == "decoder":
56+
if "up_blocks" in parts:
57+
up_idx_loc = parts.index("up_blocks") + 1
58+
up_idx = int(parts[up_idx_loc])
59+
if up_idx not in structure["decoder"]["up_blocks"]:
60+
structure["decoder"]["up_blocks"][up_idx] = 0
61+
structure["decoder"]["up_blocks"][up_idx] = max(structure["decoder"]["up_blocks"][up_idx], block_idx + 1)
62+
elif "mid_block" in parts:
63+
structure["decoder"]["mid_block"] = max(structure["decoder"]["mid_block"], block_idx + 1)
64+
65+
except (ValueError, IndexError) as e:
66+
# print(f"Skipping key {key}: {e}")
67+
continue
68+
69+
print("\nDuced VAE Structure (Layers per block):")
70+
print("Encoder:")
71+
for i in sorted(structure["encoder"]["down_blocks"].keys()):
72+
print(f" Down Block {i}: {structure['encoder']['down_blocks'][i]} layers")
73+
print(f" Mid Block: {structure['encoder']['mid_block']} layers")
74+
75+
print("Decoder:")
76+
for i in sorted(structure["decoder"]["up_blocks"].keys()):
77+
print(f" Up Block {i}: {structure['decoder']['up_blocks'][i]} layers")
78+
print(f" Mid Block: {structure['decoder']['mid_block']} layers")
79+
3880
except Exception as e:
39-
print(f"An error occurred: {e}")
81+
print(f"Error reading checkpoint: {e}")
4082

4183
if __name__ == "__main__":
42-
inspect_checkpoint()
84+
inspect_structure()

0 commit comments

Comments
 (0)