Skip to content

Commit 98a22aa

Browse files
committed
new fix
1 parent f7f9c83 commit 98a22aa

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/tests/ltx2_vae_parity_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def fix_keys(d):
110110
print("Running Flax forward pass...")
111111
# Call the model
112112
# Note: default deterministic=True, causal=True/False depending on init
113-
jax_recon = model(jax_input, sample_posterior=False)
113+
jax_recon = model(jax_input, sample_posterior=False).sample
114114

115115
# 5. Print Output Stats
116116
print("\nOutput Stats:")

0 commit comments

Comments
 (0)