Skip to content

Commit 2132468

Browse files
committed
debug_audio_vae
1 parent db4acec commit 2132468

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

debug_audio_vae.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,28 @@ def flatten(d, parent_key=()):
6767
for k in pt_keys[:20]:
6868
print(k)
6969

70+
print("\nSample Encoder Keys:")
71+
enc_keys = [k for k in pt_keys if "encoder" in k]
72+
for k in enc_keys[:20]:
73+
print(k)
74+
75+
# Check specific encoder key shape if possible?
76+
# Can't easily check shape here without loading tensor, but load_sharded_checkpoint loads all.
77+
# tensors is already loaded.
78+
79+
print("\nChecking Encoder Down Block 0 shape:")
80+
if "encoder.down.0.block.0.conv1.conv.weight" in tensors:
81+
print("encoder.down.0.block.0.conv1.conv.weight:", tensors["encoder.down.0.block.0.conv1.conv.weight"].shape)
82+
if "encoder.down.0.block.1.conv1.conv.weight" in tensors:
83+
print("encoder.down.0.block.1.conv1.conv.weight:", tensors["encoder.down.0.block.1.conv1.conv.weight"].shape)
84+
if "encoder.down.1.block.0.conv1.conv.weight" in tensors:
85+
print("encoder.down.1.block.0.conv1.conv.weight:", tensors["encoder.down.1.block.0.conv1.conv.weight"].shape)
86+
87+
print("\nChecking Decoder Up Block 0 shape:")
88+
if "decoder.up.0.block.0.conv1.conv.weight" in tensors:
89+
print("decoder.up.0.block.0.conv1.conv.weight:", tensors["decoder.up.0.block.0.conv1.conv.weight"].shape)
90+
91+
7092
print("\nTesting Renaming Logic...")
7193
renamed_keys = []
7294
for k in pt_keys:

0 commit comments

Comments
 (0)