Skip to content

Commit a257f4e

Browse files
committed
fix
1 parent 887eb8b commit a257f4e

5 files changed

Lines changed: 1 addition & 290 deletions

File tree

check_encoder_keys.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

inspect_vae_checkpoint.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

inspect_vae_structure.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

reproduce_vae_mapping.py

Lines changed: 0 additions & 88 deletions
This file was deleted.

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -339,21 +339,15 @@ def load_vae_weights(
339339
flax_state_dict[flax_key] = flax_tensor
340340
else:
341341
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
342-
343-
344-
345-
# Filter out non-parameter keys for validation
346342
filtered_eval_shapes = {}
347343
for k, v in flattened_eval.items():
348-
# flax key is a tuple of strings/ints
349344
k_str = [str(x) for x in k]
350345
if "dropout" in k_str or "rngs" in k_str:
351346
continue
352347
filtered_eval_shapes[k] = v
353348

354349
print(f"Total VAE keys loaded: {len(flax_state_dict)}")
355-
356-
# Unflatten to pass to validate_flax_state_dict which expects a pytree
350+
357351
validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flax_state_dict)
358352
flax_state_dict = unflatten_dict(flax_state_dict)
359353
del tensors

0 commit comments

Comments
 (0)