|
1 | 1 |
|
2 | | -import os |
3 | | -from huggingface_hub import snapshot_download |
4 | 2 | from safetensors import safe_open |
5 | | -import torch |
| 3 | +from huggingface_hub import snapshot_download |
| 4 | +import os |
6 | 5 |
|
7 | | -def inspect_checkpoint(): |
| 6 | +def inspect_structure(): |
| 7 | + resume_from_checkpoint = "Lightricks/LTX-Video" |
| 8 | + cache_dir = os.path.join(os.path.expanduser("~"), ".cache/huggingface/hub") |
| 9 | + |
| 10 | + vae_path = None |
| 11 | + search_path = os.path.join(cache_dir, "models--Lightricks--LTX-Video/snapshots") |
| 12 | + if os.path.exists(search_path): |
| 13 | + for root, dirs, files in os.walk(search_path): |
| 14 | + if "vae" in root and "diffusion_pytorch_model.safetensors" in files: |
| 15 | + vae_path = os.path.join(root, "diffusion_pytorch_model.safetensors") |
| 16 | + break |
| 17 | + |
| 18 | + if not vae_path: |
| 19 | + print("VAE checkpoint not found.") |
| 20 | + return |
| 21 | + |
| 22 | + print(f"Analyzing checkpoint: {vae_path}") |
| 23 | + |
| 24 | + structure = { |
| 25 | + "encoder": {"down_blocks": {}, "mid_block": 0}, |
| 26 | + "decoder": {"up_blocks": {}, "mid_block": 0} |
| 27 | + } |
| 28 | + |
8 | 29 | try: |
9 | | - # Allow looking in the user's cache |
10 | | - cache_dir = os.path.expanduser("~/.cache/huggingface/hub") |
11 | | - print(f"Scanning cache dir: {cache_dir}") |
12 | | - |
13 | | - # We know the model is Lightricks/LTX-Video |
14 | | - # We look for snapshots |
15 | | - repo_id = "Lightricks/LTX-Video" |
16 | | - |
17 | | - # Try to find it |
18 | | - # We can use hf_hub_download to find the path without downloading if it exists |
19 | | - from huggingface_hub import hf_hub_download |
20 | | - |
21 | | - try: |
22 | | - vae_path = hf_hub_download(repo_id, subfolder="vae", filename="diffusion_pytorch_model.safetensors") |
23 | | - print(f"Found VAE checkpoint at: {vae_path}") |
| 30 | + with safe_open(vae_path, framework="pt", device="cpu") as f: |
| 31 | + keys = f.keys() |
| 32 | + print(f"Total keys: {len(keys)}") |
24 | 33 |
|
25 | | - with safe_open(vae_path, framework="pt") as f: |
26 | | - keys = f.keys() |
27 | | - print(f"Total keys in VAE checkpoint: {len(keys)}") |
28 | | - print("Sample keys:") |
29 | | - for i, k in enumerate(keys): |
30 | | - if i < 20: |
31 | | - print(k) |
32 | | - if "resnets" in k and "up_blocks" in k and i % 10 == 0: |
33 | | - print(f"Resnet key sample: {k}") |
| 34 | + for key in keys: |
| 35 | + parts = key.split(".") |
| 36 | + if "resnets" not in parts: |
| 37 | + continue |
| 38 | + |
| 39 | + try: |
| 40 | + resnets_idx = parts.index("resnets") |
| 41 | + # The next part should be the index |
| 42 | + if len(parts) > resnets_idx + 1 and parts[resnets_idx + 1].isdigit(): |
| 43 | + block_idx = int(parts[resnets_idx + 1]) |
34 | 44 |
|
35 | | - except Exception as e: |
36 | | - print(f"Could not find VAE checkpoint via hf_hub_download: {e}") |
37 | | - |
| 45 | + if parts[0] == "encoder": |
| 46 | + if "down_blocks" in parts: |
| 47 | + down_idx_loc = parts.index("down_blocks") + 1 |
| 48 | + down_idx = int(parts[down_idx_loc]) |
| 49 | + if down_idx not in structure["encoder"]["down_blocks"]: |
| 50 | + structure["encoder"]["down_blocks"][down_idx] = 0 |
| 51 | + structure["encoder"]["down_blocks"][down_idx] = max(structure["encoder"]["down_blocks"][down_idx], block_idx + 1) |
| 52 | + elif "mid_block" in parts: |
| 53 | + structure["encoder"]["mid_block"] = max(structure["encoder"]["mid_block"], block_idx + 1) |
| 54 | + |
| 55 | + elif parts[0] == "decoder": |
| 56 | + if "up_blocks" in parts: |
| 57 | + up_idx_loc = parts.index("up_blocks") + 1 |
| 58 | + up_idx = int(parts[up_idx_loc]) |
| 59 | + if up_idx not in structure["decoder"]["up_blocks"]: |
| 60 | + structure["decoder"]["up_blocks"][up_idx] = 0 |
| 61 | + structure["decoder"]["up_blocks"][up_idx] = max(structure["decoder"]["up_blocks"][up_idx], block_idx + 1) |
| 62 | + elif "mid_block" in parts: |
| 63 | + structure["decoder"]["mid_block"] = max(structure["decoder"]["mid_block"], block_idx + 1) |
| 64 | + |
| 65 | + except (ValueError, IndexError) as e: |
| 66 | + # print(f"Skipping key {key}: {e}") |
| 67 | + continue |
| 68 | + |
| 69 | + print("\nDuced VAE Structure (Layers per block):") |
| 70 | + print("Encoder:") |
| 71 | + for i in sorted(structure["encoder"]["down_blocks"].keys()): |
| 72 | + print(f" Down Block {i}: {structure['encoder']['down_blocks'][i]} layers") |
| 73 | + print(f" Mid Block: {structure['encoder']['mid_block']} layers") |
| 74 | + |
| 75 | + print("Decoder:") |
| 76 | + for i in sorted(structure["decoder"]["up_blocks"].keys()): |
| 77 | + print(f" Up Block {i}: {structure['decoder']['up_blocks'][i]} layers") |
| 78 | + print(f" Mid Block: {structure['decoder']['mid_block']} layers") |
| 79 | + |
38 | 80 | except Exception as e: |
39 | | - print(f"An error occurred: {e}") |
| 81 | + print(f"Error reading checkpoint: {e}") |
40 | 82 |
|
41 | 83 | if __name__ == "__main__": |
42 | | - inspect_checkpoint() |
| 84 | + inspect_structure() |
0 commit comments