Skip to content

Commit e78a91b

Browse files
committed
adding parity check file
1 parent ceabd07 commit e78a91b

1 file changed

Lines changed: 38 additions & 6 deletions

File tree

src/maxdiffusion/tests/ltx2_vae_parity_test.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import jax
55
import jax.numpy as jnp
66
from flax import nnx
7+
from flax import traverse_util
78
from flax.training import orbax_utils
89
import orbax.checkpoint
910
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
@@ -39,16 +40,50 @@ def test_ltx2_vae_parity():
3940
graphdef, state = nnx.split(model)
4041
params = state.filter(nnx.Param)
4142

42-
# Load into structure
43+
# Load without 'item' to avoid structure mismatch errors with State vs Dict
4344
if not os.path.exists(ckpt_path):
4445
print(f"Error: Checkpoint path {ckpt_path} does not exist.")
4546
return
4647

47-
# Load without 'item' to avoid structure mismatch errors with State vs Dict
4848
loaded_params = checkpointer.restore(ckpt_path)
4949

50+
# Debug: Print structure of loaded_params
51+
print("Loaded params type:", type(loaded_params))
52+
if isinstance(loaded_params, dict):
53+
print("Loaded keys sample:", list(loaded_params.keys())[:5])
54+
# Check encoder down_blocks if present
55+
if 'encoder' in loaded_params and 'down_blocks' in loaded_params['encoder']:
56+
print("Encoder down_blocks keys:", list(loaded_params['encoder']['down_blocks'].keys()))
57+
first_key = next(iter(loaded_params['encoder']['down_blocks']))
58+
print(f"Key type: {type(first_key)}")
59+
5060
# Merge back
51-
nnx.update(model, loaded_params)
61+
try:
62+
nnx.update(model, loaded_params)
63+
except KeyError as e:
64+
print(f"Caught KeyError during update: {e}")
65+
print("Attempting to fix integer keys...")
66+
# If keys are strings but should be integers (or vice versa), fix them
67+
# nnx.List expects integer keys.
68+
# If orbax loaded them as strings '0', '1', we need to convert to int 0, 1.
69+
70+
def fix_keys(d):
71+
new_d = {}
72+
for k, v in d.items():
73+
if isinstance(v, dict):
74+
v = fix_keys(v)
75+
76+
# Check if key is a string digit
77+
if isinstance(k, str) and k.isdigit():
78+
new_k = int(k)
79+
else:
80+
new_k = k
81+
new_d[new_k] = v
82+
return new_d
83+
84+
fixed_params = fix_keys(loaded_params)
85+
print("Retrying update with fixed keys...")
86+
nnx.update(model, fixed_params)
5287

5388
# 3. Create Inputs
5489
print("Creating deterministic input...")
@@ -66,9 +101,6 @@ def test_ltx2_vae_parity():
66101
print("Running Flax forward pass...")
67102
# model(sample, sample_posterior=False) -> should return reconstructed image
68103

69-
# We use valid key for potential noise injection (though disabled in config)
70-
rngs = nnx.Rngs(0)
71-
72104
# Call the model
73105
# Note: default deterministic=True, causal=True/False depending on init
74106
jax_recon = model(jax_input, sample_posterior=False, deterministic=True)

0 commit comments

Comments
 (0)