Skip to content

Commit 1ce98d2

Browse files
committed
debug
1 parent 55c4692 commit 1ce98d2

1 file changed

Lines changed: 16 additions & 0 deletions

File tree

src/maxdiffusion/tests/ltx2_vae_parity_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,22 @@ def fix_keys(d):
7878
print("Retrying update with fixed keys...")
7979
nnx.update(model, fixed_params)
8080

81+
# Debug: Check Model Weights Shapes
82+
print("\n--- Model Weights Debug ---")
83+
try:
84+
if hasattr(model, 'encoder'):
85+
conv_in_kernel = model.encoder.conv_in.conv.kernel.value
86+
print(f"Encoder conv_in kernel shape: {conv_in_kernel.shape}")
87+
88+
# Check first resnet
89+
if len(model.encoder.down_blocks) > 0:
90+
resnet0 = model.encoder.down_blocks[0].resnets[0]
91+
conv1_kernel = resnet0.conv1.conv.kernel.value
92+
print(f"Encoder down_blocks[0].resnets[0].conv1 kernel shape: {conv1_kernel.shape}")
93+
except Exception as e:
94+
print(f"Could not inspect weights: {e}")
95+
print("---------------------------\n")
96+
8197
# 3. Create Inputs
8298
print("Creating deterministic input...")
8399
# Shape: (Batch, Frames, Height, Width, Channels) for JAX

0 commit comments

Comments
 (0)