Skip to content

Commit 6d6e227

Browse files
committed
debug_audio_vae
1 parent 2132468 commit 6d6e227

1 file changed

Lines changed: 17 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,23 @@ def load_audio_vae_weights(
565565

566566
flax_key = tuple(flax_key_parts)
567567

568+
# Reverse up_stages indices if present
569+
if "up_stages" in flax_key:
570+
# Find index of 'up_stages'
571+
try:
572+
up_stages_idx = flax_key.index("up_stages")
573+
# The integer index follows "up_stages"
574+
if up_stages_idx + 1 < len(flax_key):
575+
stage_idx = flax_key[up_stages_idx + 1]
576+
if isinstance(stage_idx, int):
577+
# Assuming 3 stages (0, 1, 2)
578+
# Map 0 -> 2, 1 -> 1, 2 -> 0
579+
new_stage_idx = 2 - stage_idx
580+
flax_key_parts[up_stages_idx + 1] = new_stage_idx
581+
flax_key = tuple(flax_key_parts)
582+
except ValueError:
583+
pass
584+
568585
flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu)
569586

570587
# Filter eval shapes to remove rngs/dropout

0 commit comments

Comments
 (0)