Skip to content

Commit 1dbb3b8

Browse files
committed
fix
1 parent 91f9c97 commit 1dbb3b8

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,8 @@ def load_vae_weights(
340340
else:
341341
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
342342

343-
# print(f"Loaded VAE Key: {flax_key}")
343+
if "decoder" in flax_key_str and "up_blocks" in flax_key_str and "0" in flax_key_str and "resnets" in flax_key_str and "2" in flax_key_str and "conv2" in flax_key_str:
344+
print(f"DEBUG: Processing target key. Final flax_key: {flax_key}")
344345

345346
print(f"Total VAE keys loaded: {len(flax_state_dict)}")
346347
validate_flax_state_dict(eval_shapes, flax_state_dict)

0 commit comments

Comments
 (0)