Skip to content

Commit 5b6fe91

Browse files
committed
debug_audio_vae
1 parent 598f702 commit 5b6fe91

2 files changed

Lines changed: 3 additions & 1 deletion

File tree

debug_audio_vae.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def flatten(d, parent_key=()):
236236

237237
# Also check if validation function itself is behaving as expected
238238
from flax.traverse_util import unflatten_dict, flatten_dict
239-
from maxdiffusion.modeling_flax_pytorch_utils import validate_flax_state_dict
239+
from maxdiffusion.models.modeling_flax_pytorch_utils import validate_flax_state_dict
240240

241241
# Construct a dummy flax_state_dict with only the keys we found
242242
# We need to map our final_keys back to a dict

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,8 @@ def load_audio_vae_weights(
577577
# Assuming 3 stages (0, 1, 2)
578578
# Map 0 -> 2, 1 -> 1, 2 -> 0
579579
new_stage_idx = 2 - stage_idx
580+
if "upsample" in flax_key:
581+
print(f"DEBUG REVERSAL: {flax_key} -> stage_idx={stage_idx} -> new={new_stage_idx}")
580582
flax_key_parts[up_stages_idx + 1] = new_stage_idx
581583
flax_key = tuple(flax_key_parts)
582584
except ValueError:

0 commit comments

Comments
 (0)