44import jax
55import jax .numpy as jnp
66from flax import nnx
7+ from flax import traverse_util
78from flax .training import orbax_utils
89import orbax .checkpoint
910from maxdiffusion .models .ltx2 .autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
@@ -39,16 +40,50 @@ def test_ltx2_vae_parity():
3940 graphdef , state = nnx .split (model )
4041 params = state .filter (nnx .Param )
4142
42- # Load into structure
43+ # Load without 'item' to avoid structure mismatch errors with State vs Dict
4344 if not os .path .exists (ckpt_path ):
4445 print (f"Error: Checkpoint path { ckpt_path } does not exist." )
4546 return
4647
47- # Load without 'item' to avoid structure mismatch errors with State vs Dict
4848 loaded_params = checkpointer .restore (ckpt_path )
4949
50+ # Debug: Print structure of loaded_params
51+ print ("Loaded params type:" , type (loaded_params ))
52+ if isinstance (loaded_params , dict ):
53+ print ("Loaded keys sample:" , list (loaded_params .keys ())[:5 ])
54+ # Check encoder down_blocks if present
55+ if 'encoder' in loaded_params and 'down_blocks' in loaded_params ['encoder' ]:
56+ print ("Encoder down_blocks keys:" , list (loaded_params ['encoder' ]['down_blocks' ].keys ()))
57+ first_key = next (iter (loaded_params ['encoder' ]['down_blocks' ]))
58+ print (f"Key type: { type (first_key )} " )
59+
5060 # Merge back
51- nnx .update (model , loaded_params )
61+ try :
62+ nnx .update (model , loaded_params )
63+ except KeyError as e :
64+ print (f"Caught KeyError during update: { e } " )
65+ print ("Attempting to fix integer keys..." )
66+ # If keys are strings but should be integers (or vice versa), fix them
67+ # nnx.List expects integer keys.
68+ # If orbax loaded them as strings '0', '1', we need to convert to int 0, 1.
69+
70+ def fix_keys (d ):
71+ new_d = {}
72+ for k , v in d .items ():
73+ if isinstance (v , dict ):
74+ v = fix_keys (v )
75+
76+ # Check if key is a string digit
77+ if isinstance (k , str ) and k .isdigit ():
78+ new_k = int (k )
79+ else :
80+ new_k = k
81+ new_d [new_k ] = v
82+ return new_d
83+
84+ fixed_params = fix_keys (loaded_params )
85+ print ("Retrying update with fixed keys..." )
86+ nnx .update (model , fixed_params )
5287
5388 # 3. Create Inputs
5489 print ("Creating deterministic input..." )
@@ -66,9 +101,6 @@ def test_ltx2_vae_parity():
66101 print ("Running Flax forward pass..." )
67102 # model(sample, sample_posterior=False) -> should return reconstructed image
68103
69- # We use valid key for potential noise injection (though disabled in config)
70- rngs = nnx .Rngs (0 )
71-
72104 # Call the model
73105 # Note: default deterministic=True, causal=True/False depending on init
74106 jax_recon = model (jax_input , sample_posterior = False , deterministic = True )
0 commit comments