Skip to content

Commit 34229db

Browse files
committed
reproduce_vae_mapping
1 parent 9e4038b commit 34229db

2 files changed

Lines changed: 89 additions & 3517 deletions

File tree

reproduce_vae_mapping.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
2+
import sys
3+
import os
4+
import torch
5+
import jax.numpy as jnp
6+
from flax.traverse_util import flatten_dict, unflatten_dict
7+
from maxdiffusion.models.modeling_flax_pytorch_utils import rename_key, rename_key_and_reshape_tensor
8+
from maxdiffusion.models.ltx2.ltx2_utils import _tuple_str_to_int
9+
10+
def test_vae_key(pt_key):
11+
print(f"\nProcessing Checkpoint Key: {pt_key}")
12+
13+
# Logic copied/adapted from load_vae_weights in ltx2_utils.py
14+
renamed_pt_key = rename_key(pt_key)
15+
# print(f"After rename_key: {renamed_pt_key}")
16+
17+
pt_tuple_key = tuple(renamed_pt_key.split("."))
18+
19+
pt_list = []
20+
21+
for i, part in enumerate(pt_tuple_key):
22+
if "_" in part and part.split("_")[-1].isdigit():
23+
name = "_".join(part.split("_")[:-1])
24+
idx = int(part.split("_")[-1])
25+
26+
if name == "resnets":
27+
pt_list.append("resnets")
28+
pt_list.append(str(idx))
29+
elif name == "upsamplers":
30+
pt_list.append("upsampler")
31+
elif name in ["down_blocks", "up_blocks", "downsamplers"]:
32+
pt_list.append(name)
33+
pt_list.append(str(idx))
34+
else:
35+
pt_list.append(part)
36+
elif part == "upsampler":
37+
pt_list.append("upsampler")
38+
elif part in ["conv1", "conv2", "conv"]:
39+
pt_list.append(part)
40+
# Logic from ltx2_utils.py
41+
if i + 1 < len(pt_tuple_key) and pt_tuple_key[i+1] == "conv":
42+
pass
43+
elif pt_list[-1] == "conv":
44+
pass
45+
elif len(pt_list) >= 2 and pt_list[-2] == "conv":
46+
pass
47+
elif part == "conv":
48+
pass
49+
else:
50+
pt_list.append("conv")
51+
else:
52+
pt_list.append(part)
53+
54+
pt_tuple_key = tuple(pt_list)
55+
print(f"Constructed PT Tuple Key: {pt_tuple_key}")
56+
57+
# Mock random_flax_state_dict for rename_key_and_reshape_tensor check
58+
# We pretend the target key exists
59+
# If pt_tuple_key ends in 'weight', we look for 'kernel'
60+
# If logic generates 'conv1.conv', we check compatibility
61+
62+
mock_flax_key = list(pt_tuple_key)
63+
if mock_flax_key[-1] == "weight":
64+
mock_flax_key[-1] = "kernel"
65+
if mock_flax_key[-1] == "bias":
66+
pass
67+
68+
mock_flax_key_tuple = tuple(mock_flax_key)
69+
random_flax_state_dict = {mock_flax_key_tuple: 1} # Dummy Exists
70+
71+
# dummy tensor
72+
import torch
73+
tensor = torch.zeros(1)
74+
75+
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
76+
flax_key = _tuple_str_to_int(flax_key)
77+
78+
print(f"Final Flax Key: {flax_key}")
79+
80+
if __name__ == "__main__":
81+
# Test cases from missing keys log
82+
test_keys = [
83+
"decoder.up_blocks.1.resnets.1.conv1.weight",
84+
"encoder.down_blocks.0.resnets.0.conv1.weight",
85+
"decoder.mid_block.resnets.0.conv1.weight", # Example guessing structure
86+
]
87+
88+
for k in test_keys:
89+
test_vae_key(k)

0 commit comments

Comments
 (0)