@@ -36,10 +36,6 @@ def test_ltx2_vae_parity():
3636
3737 checkpointer = orbax .checkpoint .Checkpointer (orbax .checkpoint .PyTreeCheckpointHandler ())
3838
39- # recreate split to get structure
40- graphdef , state = nnx .split (model )
41- params = state .filter (nnx .Param )
42-
4339 # Load without 'item' to avoid structure mismatch errors with State vs Dict
4440 if not os .path .exists (ckpt_path ):
4541 print (f"Error: Checkpoint path { ckpt_path } does not exist." )
@@ -63,9 +59,6 @@ def test_ltx2_vae_parity():
6359 except KeyError as e :
6460 print (f"Caught KeyError during update: { e } " )
6561 print ("Attempting to fix integer keys..." )
66- # If keys are strings but should be integers (or vice versa), fix them
67- # nnx.List expects integer keys.
68- # If orbax loaded them as strings '0', '1', we need to convert to int 0, 1.
6962
7063 def fix_keys (d ):
7164 new_d = {}
@@ -99,11 +92,9 @@ def fix_keys(d):
9992
10093 # 4. Run Flax
10194 print ("Running Flax forward pass..." )
102- # model(sample, sample_posterior=False) -> should return reconstructed image
103-
10495 # Call the model
10596 # Note: default deterministic=True, causal=True/False depending on init
106- jax_recon = model (jax_input , sample_posterior = False , deterministic = True )
97+ jax_recon = model (jax_input , sample_posterior = False )
10798
10899 # 5. Print Output Stats
109100 print ("\n Output Stats:" )
0 commit comments