Skip to content

Commit d6dca18

Browse files
committed
inspect vae structure
1 parent 54d3fe9 commit d6dca18

1 file changed

Lines changed: 55 additions & 0 deletions

File tree

inspect_vae_structure.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
2+
import jax
3+
import jax.numpy as jnp
4+
from flax import nnx
5+
from flax.traverse_util import flatten_dict
6+
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
7+
8+
def inspect_structure():
9+
model = LTX2VideoAutoencoderKL(
10+
in_channels=3,
11+
out_channels=3,
12+
latent_channels=128,
13+
block_out_channels=(8, 16), # Small for speed
14+
layers_per_block=(1, 1), # Small for speed
15+
decoder_layers_per_block=(1, 1),
16+
spatio_temporal_scaling=(True, True),
17+
decoder_spatio_temporal_scaling=(True, True),
18+
decoder_inject_noise=(False, False),
19+
downsample_type=("spatial", "temporal"),
20+
upsample_residual=(True, True),
21+
upsample_factor=(2, 2)
22+
)
23+
24+
state = nnx.state(model)
25+
eval_shapes = state.to_pure_dict()
26+
flat_shapes = flatten_dict(eval_shapes)
27+
28+
print(f"Total keys: {len(flat_shapes)}")
29+
30+
# Check for resnets keys
31+
resnet_keys = [k for k in flat_shapes.keys() if "resnets" in [str(x) for x in k]]
32+
print("\nResnet keys sample:")
33+
for k in resnet_keys[:5]:
34+
print(f"{k}: {flat_shapes[k].shape}")
35+
36+
# Check for conv_in keys
37+
conv_in_keys = [k for k in flat_shapes.keys() if "conv_in" in [str(x) for x in k]]
38+
print("\nConv_in keys sample:")
39+
for k in conv_in_keys[:5]:
40+
print(f"{k}")
41+
42+
# Check for conv_out keys
43+
conv_out_keys = [k for k in flat_shapes.keys() if "conv_out" in [str(x) for x in k]]
44+
print("\nConv_out keys sample:")
45+
for k in conv_out_keys[:5]:
46+
print(f"{k}")
47+
48+
# Check for conv1 keys inside resnets
49+
conv1_keys = [k for k in resnet_keys if "conv1" in [str(x) for x in k]]
50+
print("\nConv1 keys inside resnets sample:")
51+
for k in conv1_keys[:5]:
52+
print(f"{k}")
53+
54+
if __name__ == "__main__":
55+
inspect_structure()

0 commit comments

Comments
 (0)