Skip to content

Commit ee30474

Browse files
committed
adding parity check file
1 parent 77fc2db commit ee30474

1 file changed

Lines changed: 34 additions & 81 deletions

File tree

Lines changed: 34 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,16 @@
11

22
import os
3-
import torch
43
import numpy as np
54
import jax
65
import jax.numpy as jnp
76
from flax import nnx
87
from flax.training import orbax_utils
98
import orbax.checkpoint
10-
from diffusers import AutoencoderKLLTXVideo
119
from maxdiffusion.models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
12-
from maxdiffusion import pyconfig
13-
from maxdiffusion import max_utils
1410

1511
def test_ltx2_vae_parity():
16-
# 1. Load PyTorch Model
17-
print("Loading PyTorch model...")
18-
pt_model = AutoencoderKLLTXVideo.from_pretrained(
19-
"Lightricks/LTX-2",
20-
subfolder="vae",
21-
torch_dtype=torch.float32
22-
)
23-
pt_model.eval()
24-
25-
# 2. Load Flax Model
26-
print("Loading Flax model...")
12+
# 1. Load Flax Model
13+
print("Initializing MaxDiffusion model...")
2714
# Initialize with same config as conversion
2815
model = LTX2VideoAutoencoderKL(
2916
in_channels=3,
@@ -53,89 +40,55 @@ def test_ltx2_vae_parity():
5340
params = state.filter(nnx.Param)
5441

5542
# Load into structure
43+
if not os.path.exists(ckpt_path):
44+
print(f"Error: Checkpoint path {ckpt_path} does not exist.")
45+
return
46+
5647
loaded_params = checkpointer.restore(ckpt_path, item=params)
5748

5849
# Merge back
5950
nnx.update(model, loaded_params)
6051

6152
# 3. Create Inputs
62-
# Shape: (Batch, Channels, Frames, Height, Width)
63-
# LTX-2 uses (B, C, F, H, W) for PT
64-
# MaxDiffusion uses (B, F, H, W, C) for JAX
53+
print("Creating deterministic input...")
54+
# Shape: (Batch, Frames, Height, Width, Channels) for JAX
55+
# Using fixed seed for reproducibility
56+
key = jax.random.PRNGKey(42)
57+
B, F, H, W, C = 1, 17, 64, 64, 3
6558

66-
B, C, F, H, W = 1, 3, 17, 64, 64 # Small input for speed (F=1 + 16 for patching?)
67-
# F should be compatible with temporal patch (1) and scaling.
68-
# PT model expects specific structure?
59+
jax_input = jax.random.normal(key, (B, F, H, W, C), dtype=jnp.float32)
6960

70-
torch.manual_seed(42)
71-
pt_input = torch.randn(B, C, F, H, W, dtype=torch.float32)
72-
73-
# 4. Run PyTorch
74-
print("Running PyTorch forward pass...")
75-
with torch.no_grad():
76-
pt_output = pt_model(pt_input, sample_posterior=True).latent_dist.mode() # Compare encoded latents? Or full round trip?
77-
# Let's compare ENCODER first as valid middle ground, then full decode
78-
79-
pt_enc_dist = pt_model.encode(pt_input).latent_dist
80-
pt_latents = pt_enc_dist.mode()
81-
82-
pt_recon = pt_model.decode(pt_latents).sample
61+
print(f"Input Shape: {jax_input.shape}")
62+
print(f"Input Stats: Mean={jax_input.mean():.6f}, Std={jax_input.std():.6f}, Min={jax_input.min():.6f}, Max={jax_input.max():.6f}")
8363

84-
# 5. Run Flax
64+
# 4. Run Flax
8565
print("Running Flax forward pass...")
86-
# Convert input to JAX format: (B, F, H, W, C)
87-
jax_input = jnp.array(pt_input.permute(0, 2, 3, 4, 1).numpy())
88-
89-
# Encode
90-
# AutoencoderKL usually returns distribution or sample
91-
# LTX2VideoAutoencoderKL.encode returns (params, rngs) -> but we are calling methods directly if split?
92-
# No, we called nnx.merge or update. model is stateful.
93-
94-
# model.encode(sample, return_dict=False) -> (mean, logvar) ??
95-
# Checking implementation of encode in autoencoder_kl_ltx2.py check...
66+
# model(sample, sample_posterior=False) -> should return reconstructed image
9667

68+
# We use valid key for potential noise injection (though disabled in config)
9769
rngs = nnx.Rngs(0)
98-
# We need to call it appropriately.
99-
# The class has __call__ which does encode -> decode (round trip)
10070

101-
# Round trip match
102-
jax_recon = model(jax_input, sample_posterior=False, deterministic=True) # mode() equivalent?
71+
# Call the model
72+
# Note: default deterministic=True, causal=True/False depending on init
73+
jax_recon = model(jax_input, sample_posterior=False, deterministic=True)
10374

104-
# Check encode separately if possible, but __call__ is easiest for verified end-to-end
75+
# 5. Print Output Stats
76+
print("\nOutput Stats:")
77+
print(f"Output Shape: {jax_recon.shape}")
78+
print(f"Output Mean: {jax_recon.mean():.6f}")
79+
print(f"Output Std: {jax_recon.std():.6f}")
80+
print(f"Output Min: {jax_recon.min():.6f}")
81+
print(f"Output Max: {jax_recon.max():.6f}")
10582

106-
# For fair comparison with mode(), we need to tell JAX to sample mode.
107-
# If sample_posterior=True, it samples.
108-
# If sample_posterior=False, it returns mode (usually).
109-
110-
# 6. Compare Outputs
111-
print("Comparing outputs...")
112-
113-
# JAX output: (B, F, H, W, C) -> PT: (B, C, F, H, W)
114-
jax_recon_pt = torch.tensor(np.array(jax_recon)).permute(0, 4, 1, 2, 3)
115-
116-
diff = (pt_recon - jax_recon_pt).abs()
117-
mae = diff.mean().item()
118-
max_diff = diff.max().item()
119-
120-
print(f"Mean Absolute Error: {mae}")
121-
print(f"Max Difference: {max_diff}")
122-
123-
if max_diff > 1e-3: # Loose tolerance initially
124-
print("❌ Parity Check FAILED")
125-
else:
126-
print("✅ Parity Check PASSED")
127-
12883
# Also Check Encoder Latents
129-
print("\nComparing Encoder Latents...")
130-
# Flax Encode
131-
# model.encode returns diagonal_gaussian_distribution
84+
print("\nEncoder Latents Stats:")
13285
posterior = model.encode(jax_input)
133-
jax_latents = posterior.mode()
134-
135-
jax_latents_pt = torch.tensor(np.array(jax_latents)).permute(0, 4, 1, 2, 3)
136-
diff_latents = (pt_latents - jax_latents_pt).abs()
137-
print(f"Latents MAE: {diff_latents.mean().item()}")
138-
print(f"Latents Max Diff: {diff_latents.max().item()}")
86+
# posterior is DiagonalGaussianDistribution
87+
# Check mode
88+
latents = posterior.mode()
89+
print(f"Latents Shape: {latents.shape}")
90+
print(f"Latents Mean: {latents.mean():.6f}")
91+
print(f"Latents Std: {latents.std():.6f}")
13992

14093
if __name__ == "__main__":
14194
test_ltx2_vae_parity()

0 commit comments

Comments
 (0)