Skip to content

Commit 55c4692

Browse files
committed
fix
1 parent e78a91b commit 55c4692

1 file changed

Lines changed: 1 addition & 10 deletions

File tree

src/maxdiffusion/tests/ltx2_vae_parity_test.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ def test_ltx2_vae_parity():
3636

3737
checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
3838

39-
# recreate split to get structure
40-
graphdef, state = nnx.split(model)
41-
params = state.filter(nnx.Param)
42-
4339
# Load without 'item' to avoid structure mismatch errors with State vs Dict
4440
if not os.path.exists(ckpt_path):
4541
print(f"Error: Checkpoint path {ckpt_path} does not exist.")
@@ -63,9 +59,6 @@ def test_ltx2_vae_parity():
6359
except KeyError as e:
6460
print(f"Caught KeyError during update: {e}")
6561
print("Attempting to fix integer keys...")
66-
# If keys are strings but should be integers (or vice versa), fix them
67-
# nnx.List expects integer keys.
68-
# If orbax loaded them as strings '0', '1', we need to convert to int 0, 1.
6962

7063
def fix_keys(d):
7164
new_d = {}
@@ -99,11 +92,9 @@ def fix_keys(d):
9992

10093
# 4. Run Flax
10194
print("Running Flax forward pass...")
102-
# model(sample, sample_posterior=False) -> should return reconstructed image
103-
10495
# Call the model
10596
# Note: default deterministic=True, causal=True/False depending on init
106-
jax_recon = model(jax_input, sample_posterior=False, deterministic=True)
97+
jax_recon = model(jax_input, sample_posterior=False)
10798

10899
# 5. Print Output Stats
109100
print("\nOutput Stats:")

0 commit comments

Comments
 (0)