Skip to content

Commit 7f75eeb

Browse files
committed
check_encoder_keys
1 parent 97d4df9 commit 7f75eeb

1 file changed

Lines changed: 54 additions & 0 deletions

File tree

check_encoder_keys.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
2+
from safetensors import safe_open
3+
from huggingface_hub import snapshot_download
4+
import os
5+
6+
def check_encoder():
7+
resume_from_checkpoint = "Lightricks/LTX-Video"
8+
cache_dir = os.path.join(os.path.expanduser("~"), ".cache/huggingface/hub")
9+
10+
print(f"Scanning cache dir: {cache_dir}")
11+
12+
vae_path = None
13+
# Try to find the specific file
14+
search_path = os.path.join(cache_dir, "models--Lightricks--LTX-Video/snapshots")
15+
if os.path.exists(search_path):
16+
for root, dirs, files in os.walk(search_path):
17+
if "vae" in root and "diffusion_pytorch_model.safetensors" in files:
18+
vae_path = os.path.join(root, "diffusion_pytorch_model.safetensors")
19+
break
20+
21+
if not vae_path:
22+
print("VAE checkpoint not found in cache. Downloading...")
23+
# Fallback to downloading if not found (though user seems to have it)
24+
try:
25+
download_path = snapshot_download(repo_id=resume_from_checkpoint, allow_patterns=["vae/*"])
26+
vae_path = os.path.join(download_path, "vae", "diffusion_pytorch_model.safetensors")
27+
except Exception as e:
28+
print(f"Failed to download: {e}")
29+
return
30+
31+
print(f"Found VAE checkpoint at: {vae_path}")
32+
33+
try:
34+
with safe_open(vae_path, framework="pt", device="cpu") as f:
35+
keys = f.keys()
36+
encoder_keys = [k for k in keys if "encoder" in k]
37+
decoder_keys = [k for k in keys if "decoder" in k]
38+
39+
print(f"Total keys: {len(keys)}")
40+
print(f"Encoder keys count: {len(encoder_keys)}")
41+
print(f"Decoder keys count: {len(decoder_keys)}")
42+
43+
if len(encoder_keys) > 0:
44+
print("First 5 encoder keys:")
45+
for k in encoder_keys[:5]:
46+
print(k)
47+
else:
48+
print("NO ENCODER KEYS FOUND.")
49+
50+
except Exception as e:
51+
print(f"Error reading checkpoint: {e}")
52+
53+
if __name__ == "__main__":
54+
check_encoder()

0 commit comments

Comments
 (0)