We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f7f9c83 commit 98a22aaCopy full SHA for 98a22aa
1 file changed
src/maxdiffusion/tests/ltx2_vae_parity_test.py
@@ -110,7 +110,7 @@ def fix_keys(d):
110
print("Running Flax forward pass...")
111
# Call the model
112
# Note: default deterministic=True, causal=True/False depending on init
113
- jax_recon = model(jax_input, sample_posterior=False)
+ jax_recon = model(jax_input, sample_posterior=False).sample
114
115
# 5. Print Output Stats
116
print("\nOutput Stats:")
0 commit comments