Skip to content

Commit 54bd279

Browse files
committed
cleanup
1 parent 42a2c64 commit 54bd279

3 files changed

Lines changed: 4 additions & 314 deletions

File tree

debug_audio_vae.py

Lines changed: 0 additions & 279 deletions
This file was deleted.

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -487,37 +487,25 @@ def load_connector_weights(
487487
return unflatten_dict(flax_state_dict)
488488

489489
def rename_for_ltx2_audio_vae(key):
490-
# LTX-2 Audio VAE specific renaming
491-
492-
# 1. Common renames
493490
if key.endswith(".weight"):
494491
key = key.replace(".weight", ".kernel")
495-
496-
# 2. Structure renames
497-
# mid.block_1 -> mid_block1
492+
498493
key = key.replace("mid.block_1", "mid_block1")
499494
key = key.replace("mid.block_2", "mid_block2")
500-
key = key.replace("mid.attn_1", "mid_attn") # Assumption, but safe
495+
key = key.replace("mid.attn_1", "mid_attn")
501496

502-
# up.0 -> up_stages.0
503497
key = key.replace("up.", "up_stages.")
504-
# down.0 -> down_stages.0
505498
key = key.replace("down.", "down_stages.")
506499

507-
# block.0 -> blocks.0
508500
key = key.replace("block.", "blocks.")
509501

510-
# nin_shortcut -> conv_shortcut_layer
511502
key = key.replace("nin_shortcut", "conv_shortcut_layer")
512503

513-
# In case upsample/downsample keys are just 'upsample' / 'downsample'
514-
# Check for CausalConv wrapping in upsample
515-
# If PT is 'upsample.conv.kernel' but Flax needs 'upsample.conv.conv.kernel'
516504
if "upsample.conv.kernel" in key:
517505
key = key.replace("upsample.conv.kernel", "upsample.conv.conv.kernel")
518506
if "upsample.conv.bias" in key:
519507
key = key.replace("upsample.conv.bias", "upsample.conv.conv.bias")
520-
508+
521509
return key
522510

523511

@@ -541,20 +529,13 @@ def load_audio_vae_weights(
541529
for pt_key, tensor in tensors.items():
542530
key = rename_for_ltx2_audio_vae(pt_key)
543531

544-
# Determine if we need to transpose (Conv weights: OHWI -> HWIO)
545-
# Note: 1x1 convs might also be 4D.
546-
# Standard Flax Conv: (H, W, I, O)
547-
# Standard PyTorch Conv: (O, I, H, W)
548-
549532
should_transpose = False
550533
if key.endswith(".kernel"):
551534
if tensor.ndim == 4:
552535
should_transpose = True
553536

554537
if should_transpose:
555538
tensor = tensor.transpose(2, 3, 1, 0)
556-
557-
# Convert key to tuple
558539
parts = key.split(".")
559540
flax_key_parts = []
560541
for part in parts:
@@ -565,29 +546,19 @@ def load_audio_vae_weights(
565546

566547
flax_key = tuple(flax_key_parts)
567548

568-
# Reverse up_stages indices if present
569549
if "up_stages" in flax_key:
570-
# Find index of 'up_stages'
571550
try:
572551
up_stages_idx = flax_key.index("up_stages")
573-
# The integer index follows "up_stages"
574552
if up_stages_idx + 1 < len(flax_key):
575553
stage_idx = flax_key[up_stages_idx + 1]
576554
if isinstance(stage_idx, int):
577-
# Assuming 3 stages (0, 1, 2)
578-
# Map 0 -> 2, 1 -> 1, 2 -> 0
579555
new_stage_idx = 2 - stage_idx
580-
if "upsample" in flax_key:
581-
# print(f"DEBUG REVERSAL: {flax_key} -> stage_idx={stage_idx} -> new={new_stage_idx}")
582-
pass
583556
flax_key_parts[up_stages_idx + 1] = new_stage_idx
584557
flax_key = tuple(flax_key_parts)
585558
except ValueError:
586559
pass
587560

588561
flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu)
589-
590-
# Filter eval shapes to remove rngs/dropout
591562
filtered_eval_shapes = {}
592563
for k, v in flattened_eval.items():
593564
k_str = [str(x) for x in k]

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,7 @@ def test_load_audio_vae_weights(self):
188188
from maxdiffusion.models.ltx2.ltx2_utils import load_audio_vae_weights
189189

190190
pretrained_model_name_or_path = "Lightricks/LTX-2"
191-
192-
# Audio VAE Config from user request
191+
193192
config = {
194193
"base_channels": 128,
195194
"ch_mult": (1, 2, 4),
@@ -223,7 +222,6 @@ def test_load_audio_vae_weights(self):
223222
)
224223

225224
print("Validating Audio VAE Weights...")
226-
# Filter eval_shapes for validation as load_audio_vae_weights returns filtered weights
227225
filtered_eval_shapes = {}
228226
flat_eval = flatten_dict(eval_shapes)
229227
for k, v in flat_eval.items():

0 commit comments

Comments
 (0)