Skip to content

Commit 77fc2db

Browse files
committed
adding parity check file
1 parent c50a707 commit 77fc2db

1 file changed

Lines changed: 141 additions & 0 deletions

File tree

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
2+
import os
3+
import torch
4+
import numpy as np
5+
import jax
6+
import jax.numpy as jnp
7+
from flax import nnx
8+
from flax.training import orbax_utils
9+
import orbax.checkpoint
10+
from diffusers import AutoencoderKLLTXVideo
11+
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
12+
from maxdiffusion import pyconfig
13+
from maxdiffusion import max_utils
14+
15+
def test_ltx2_vae_parity(ckpt_dir: str = "ltx2_vae_checkpoint", tol: float = 1e-3):
    """Numerical parity check between the PyTorch and Flax LTX-2 video VAEs.

    Feeds the same random video tensor through both implementations and
    compares (a) the full encode -> decode reconstruction and (b) the raw
    encoder latents (the posterior distribution's mode), printing MAE and
    max absolute difference for each.

    Args:
      ckpt_dir: Directory holding the Orbax checkpoint produced by the
        PyTorch -> Flax weight conversion.
      tol: Maximum absolute difference tolerated for the reconstruction
        check (loose by default; tighten once parity is established).
    """
    # 1. Reference PyTorch model: float32 weights, eval mode so that
    # train-only behavior (e.g. dropout) cannot skew the comparison.
    print("Loading PyTorch model...")
    pt_model = AutoencoderKLLTXVideo.from_pretrained(
        "Lightricks/LTX-2",
        subfolder="vae",
        torch_dtype=torch.float32,
    )
    pt_model.eval()

    # 2. Flax model. This config must match the one used when the weights
    # were converted, otherwise the checkpoint restore below will not map.
    print("Loading Flax model...")
    model = LTX2VideoAutoencoderKL(
        in_channels=3,
        out_channels=3,
        latent_channels=128,
        block_out_channels=(256, 512, 1024, 2048),
        decoder_block_out_channels=(256, 512, 1024),
        layers_per_block=(4, 6, 6, 2, 2),
        decoder_layers_per_block=(5, 5, 5, 5),
        spatio_temporal_scaling=(True, True, True, True),
        decoder_spatio_temporal_scaling=(True, True, True),
        decoder_inject_noise=(False, False, False, False),
        upsample_factor=(2, 2, 2),
        upsample_residual=(False, False, False),
        dtype=jnp.float32,
        rngs=nnx.Rngs(0),
    )

    # Restore the converted weights from the Orbax checkpoint.
    ckpt_path = os.path.abspath(ckpt_dir)
    print(f"Loading checkpoint from {ckpt_path}...")
    checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
    # Split only to obtain the parameter tree structure the checkpoint
    # restores into; the graphdef itself is not needed here.
    _, state = nnx.split(model)
    params = state.filter(nnx.Param)
    loaded_params = checkpointer.restore(ckpt_path, item=params)
    # Merge the restored parameters back into the stateful model.
    nnx.update(model, loaded_params)

    # 3. Inputs. PyTorch layout is (B, C, F, H, W); the Flax model expects
    # channels-last (B, F, H, W, C). F=17 — assumes the VAE's temporal
    # patching wants frames of the form 1 + 8*k; TODO confirm the exact
    # constraint against the model config.
    B, C, F, H, W = 1, 3, 17, 64, 64  # small input for speed
    torch.manual_seed(42)
    pt_input = torch.randn(B, C, F, H, W, dtype=torch.float32)

    # 4. PyTorch reference: encoder latents via the posterior mode (the
    # deterministic choice, so no sampling noise enters the comparison),
    # then a full decode. no_grad keeps memory down and avoids autograd.
    print("Running PyTorch forward pass...")
    with torch.no_grad():
        pt_latents = pt_model.encode(pt_input).latent_dist.mode()
        pt_recon = pt_model.decode(pt_latents).sample

    # 5. Flax round trip. sample_posterior=False makes the model use the
    # posterior mode, matching .mode() on the PyTorch side above.
    print("Running Flax forward pass...")
    jax_input = jnp.array(pt_input.permute(0, 2, 3, 4, 1).numpy())  # -> (B, F, H, W, C)
    jax_recon = model(jax_input, sample_posterior=False, deterministic=True)

    # 6. Compare reconstructions. JAX output (B, F, H, W, C) is permuted
    # back to the PyTorch layout (B, C, F, H, W) before differencing.
    print("Comparing outputs...")
    jax_recon_pt = torch.tensor(np.array(jax_recon)).permute(0, 4, 1, 2, 3)
    diff = (pt_recon - jax_recon_pt).abs()
    mae = diff.mean().item()
    max_diff = diff.max().item()
    print(f"Mean Absolute Error: {mae}")
    print(f"Max Difference: {max_diff}")
    if max_diff > tol:
        print("❌ Parity Check FAILED")
    else:
        print("✅ Parity Check PASSED")

    # 7. Also compare the encoder latents directly — this localizes a
    # mismatch to the encoder vs. the decoder.
    print("\nComparing Encoder Latents...")
    posterior = model.encode(jax_input)
    jax_latents = posterior.mode()
    jax_latents_pt = torch.tensor(np.array(jax_latents)).permute(0, 4, 1, 2, 3)
    diff_latents = (pt_latents - jax_latents_pt).abs()
    print(f"Latents MAE: {diff_latents.mean().item()}")
    print(f"Latents Max Diff: {diff_latents.max().item()}")
# Allow this parity check to be executed directly as a script.
if __name__ == "__main__":
    test_ltx2_vae_parity()

0 commit comments

Comments
 (0)