Skip to content

Commit 27f7577

Browse files
committed
debug_audio_vae
1 parent 137d41a commit 27f7577

1 file changed

Lines changed: 211 additions & 0 deletions

File tree

debug_audio_vae.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
2+
import jax
3+
import sys
4+
import os
5+
6+
# Add src to path
7+
sys.path.append(os.path.join(os.getcwd(), "src"))
8+
9+
from maxdiffusion.models.ltx2.audio_vae import FlaxAutoencoderKLLTX2Audio
10+
from maxdiffusion.models.ltx2.ltx2_utils import load_audio_vae_weights, rename_for_ltx2_audio_vae
11+
from maxdiffusion.utils import load_sharded_checkpoint
12+
from flax import nnx
13+
14+
def debug_keys():
15+
print("Initializing Model...")
16+
config = {
17+
"base_channels": 128,
18+
"ch_mult": (1, 2, 4),
19+
"double_z": True,
20+
"dropout": 0.0,
21+
"in_channels": 2,
22+
"latent_channels": 8,
23+
"mel_bins": 64,
24+
"mel_hop_length": 160,
25+
"mid_block_add_attention": False,
26+
"norm_type": "pixel",
27+
"num_res_blocks": 2,
28+
"output_channels": 2,
29+
"resolution": 256,
30+
"sample_rate": 16000,
31+
"rngs": nnx.Rngs(0)
32+
}
33+
34+
with jax.default_device(jax.devices("cpu")[0]):
35+
model = FlaxAutoencoderKLLTX2Audio(**config)
36+
37+
state = nnx.state(model)
38+
eval_shapes = state.to_pure_dict()
39+
40+
# Print some expected Flax keys
41+
print("\nSample Flax Keys (Expected):")
42+
43+
def flatten(d, parent_key=()):
44+
items = []
45+
for k, v in d.items():
46+
new_key = parent_key + (k,)
47+
if isinstance(v, dict):
48+
items.extend(flatten(v, new_key))
49+
else:
50+
items.append(new_key)
51+
return items
52+
53+
flax_keys = flatten(eval_shapes)
54+
for k in flax_keys[:20]:
55+
print(k)
56+
57+
print("\nTotal Flax Keys:", len(flax_keys))
58+
59+
# Load PyTorch keys
60+
print("\nLoading PyTorch SafeTensors Keys...")
61+
pretrained_model_name_or_path = "Lightricks/LTX-2"
62+
subfolder = "audio_vae"
63+
64+
tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, "cpu")
65+
pt_keys = list(tensors.keys())
66+
67+
print("\nSample PyTorch Keys (Original):")
68+
for k in pt_keys[:20]:
69+
print(k)
70+
71+
print("\nTesting Renaming Logic...")
72+
renamed_keys = []
73+
for k in pt_keys:
74+
renamed = rename_for_ltx2_audio_vae(k)
75+
renamed_keys.append(renamed)
76+
if "mid_block.resnets.0.conv1.weight" in k:
77+
print(f"Renaming check: {k} -> {renamed}")
78+
79+
# Check for misaligned expected keys
80+
# specific missing ones
81+
targets = [
82+
('decoder', 'mid_block1', 'conv1', 'conv', 'bias'),
83+
('decoder', 'mid_block1', 'conv1', 'conv', 'kernel'),
84+
]
85+
86+
print("\nSearching for targets in RENAMED keys:")
87+
for t in targets:
88+
t_str = ".".join([str(x) for x in t])
89+
found = False
90+
for rk in renamed_keys:
91+
# We need to simulate the structure mapping logic too?
92+
# rename_for_ltx2_audio_vae only does string replacement,
93+
# load_audio_vae_weights does structural mapping (mid_block -> mid_block1)
94+
pass
95+
96+
# Let's verify specific renaming for mid_block1
97+
# PyTorch: decoder.mid_block.resnets.0.conv1.weight
98+
# My rename: decoder.mid_block.resnets.0.conv1.conv.kernel
99+
# My logic in load_audio_vae_weights:
100+
# if "mid_block.resnets.0" in k: replace with mid_block1
101+
# -> decoder.mid_block1.conv1.conv.kernel
102+
# Flax expected: ('decoder', 'mid_block1', 'conv1', 'conv', 'kernel')
103+
104+
# Is it possible that 'mid_block.resnets.0' is NOT in the key?
105+
# Maybe it's 'mid_block.resnets.0.conv1.weight'? Yes.
106+
107+
# We will print all RENAMED and STRUCTURED keys produced by our logic
108+
print("\nGenerating final Flax keys from PyTorch keys using current logic...")
109+
final_keys = set()
110+
111+
for pt_key in pt_keys:
112+
key = rename_for_ltx2_audio_vae(pt_key)
113+
114+
# Determine conversion to tuple (Same logic as in ltx2_utils.py)
115+
parts = key.split(".")
116+
flax_key_parts = []
117+
for part in parts:
118+
if part.isdigit():
119+
flax_key_parts.append(int(part))
120+
else:
121+
flax_key_parts.append(part)
122+
flax_key = tuple(flax_key_parts)
123+
124+
if "mid_block" in pt_key:
125+
if "mid_block.resnets.0" in pt_key:
126+
flax_key_str = ".".join([str(x) for x in flax_key])
127+
flax_key_str = flax_key_str.replace("mid_block.resnets.0", "mid_block1")
128+
elif "mid_block.resnets.1" in pt_key:
129+
flax_key_str = ".".join([str(x) for x in flax_key])
130+
flax_key_str = flax_key_str.replace("mid_block.resnets.1", "mid_block2")
131+
elif "mid_block.attentions.0" in pt_key:
132+
flax_key_str = ".".join([str(x) for x in flax_key])
133+
flax_key_str = flax_key_str.replace("mid_block.attentions.0", "mid_attn")
134+
else:
135+
flax_key_str = ".".join([str(x) for x in flax_key])
136+
137+
parts = flax_key_str.split(".")
138+
flax_key_parts = []
139+
for part in parts:
140+
if part.isdigit():
141+
flax_key_parts.append(int(part))
142+
else:
143+
flax_key_parts.append(part)
144+
flax_key = tuple(flax_key_parts)
145+
146+
if "down_blocks" in key:
147+
key_str = ".".join([str(x) for x in flax_key])
148+
if "resnets" in key_str:
149+
key_str = key_str.replace("down_blocks", "down_stages")
150+
key_str = key_str.replace("resnets", "blocks")
151+
elif "attentions" in key_str:
152+
key_str = key_str.replace("down_blocks", "down_stages")
153+
key_str = key_str.replace("attentions", "attns")
154+
elif "downsamplers" in key_str:
155+
key_str = key_str.replace("down_blocks", "down_stages")
156+
key_str = key_str.replace("downsamplers.0", "downsample")
157+
158+
parts = key_str.split(".")
159+
flax_key_parts = []
160+
for part in parts:
161+
if part.isdigit():
162+
flax_key_parts.append(int(part))
163+
else:
164+
flax_key_parts.append(part)
165+
flax_key = tuple(flax_key_parts)
166+
167+
if "up_blocks" in key:
168+
key_str = ".".join([str(x) for x in flax_key])
169+
if "resnets" in key_str:
170+
key_str = key_str.replace("up_blocks", "up_stages")
171+
key_str = key_str.replace("resnets", "blocks")
172+
elif "attentions" in key_str:
173+
key_str = key_str.replace("up_blocks", "up_stages")
174+
key_str = key_str.replace("attentions", "attns")
175+
elif "upsamplers" in key_str:
176+
key_str = key_str.replace("up_blocks", "up_stages")
177+
key_str = key_str.replace("upsamplers.0", "upsample")
178+
179+
parts = key_str.split(".")
180+
flax_key_parts = []
181+
for part in parts:
182+
if part.isdigit():
183+
flax_key_parts.append(int(part))
184+
else:
185+
flax_key_parts.append(part)
186+
flax_key = tuple(flax_key_parts)
187+
188+
final_keys.add(flax_key)
189+
190+
print("\nComparing Final Keys vs Expected Keys...")
191+
flax_keys_set = set(flax_keys)
192+
missing = flax_keys_set - final_keys
193+
194+
# Filter stats
195+
filtered_missing = []
196+
for k in missing:
197+
k_str = [str(x) for x in k]
198+
if "dropout" in k_str or "rngs" in k_str:
199+
continue
200+
filtered_missing.append(k)
201+
202+
print(f"Missing Keys (Count: {len(filtered_missing)}):")
203+
for k in sorted(filtered_missing)[:20]:
204+
print(k)
205+
206+
print("\nExtra Keys (Count: {len(final_keys - flax_keys_set)}):")
207+
for k in sorted(list(final_keys - flax_keys_set))[:20]:
208+
print(k)
209+
210+
if __name__ == "__main__":
211+
debug_keys()

0 commit comments

Comments
 (0)