Skip to content

Commit 8a1ef9c

Browse files
committed
fix
1 parent 34229db commit 8a1ef9c

2 files changed

Lines changed: 5 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,9 @@ def load_vae_weights(
340340
else:
341341
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
342342

343+
# print(f"Loaded VAE Key: {flax_key}")
344+
345+
print(f"Total VAE keys loaded: {len(flax_state_dict)}")
343346
validate_flax_state_dict(eval_shapes, flax_state_dict)
344347
flax_state_dict = unflatten_dict(flax_state_dict)
345348
del tensors

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ def test_load_vae_weights(self):
108108
print("Validating VAE Weights...")
109109
# Filter out dropout/rngs keys from eval_shapes as they are not expected in weights
110110
filtered_eval_shapes = {}
111-
for k, v in eval_shapes.items():
111+
flat_eval_shapes = flatten_dict(eval_shapes)
112+
for k, v in flat_eval_shapes.items():
112113
k_str = [str(x) for x in k]
113114
if "dropout" in k_str or "rngs" in k_str:
114115
continue

0 commit comments

Comments
 (0)