We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 91f9c97 commit 1dbb3b8Copy full SHA for 1dbb3b8
1 file changed
src/maxdiffusion/models/ltx2/ltx2_utils.py
@@ -340,7 +340,8 @@ def load_vae_weights(
340
else:
341
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
342
343
- # print(f"Loaded VAE Key: {flax_key}")
+ if "decoder" in flax_key_str and "up_blocks" in flax_key_str and "0" in flax_key_str and "resnets" in flax_key_str and "2" in flax_key_str and "conv2" in flax_key_str:
344
+ print(f"DEBUG: Processing target key. Final flax_key: {flax_key}")
345
346
print(f"Total VAE keys loaded: {len(flax_state_dict)}")
347
validate_flax_state_dict(eval_shapes, flax_state_dict)
0 commit comments