Skip to content

Commit 137d41a

Browse files
committed
fix for audio vae weights
1 parent c109dbd commit 137d41a

1 file changed

Lines changed: 59 additions & 74 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 59 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -490,8 +490,10 @@ def rename_for_ltx2_audio_vae(key):
490490
# Standard VAE renaming (resblocks -> resnets, ups -> upsamplers)
491491
key = key.replace("resblocks", "resnets")
492492
key = key.replace("ups", "upsamplers")
493-
key = key.replace("conv_shortcut.weight", "conv_shortcut_layer.kernel")
494-
key = key.replace("conv_shortcut.bias", "conv_shortcut_layer.bias")
493+
494+
# conv_shortcut -> conv_shortcut_layer.conv (Causal)
495+
key = key.replace("conv_shortcut.weight", "conv_shortcut_layer.conv.kernel")
496+
key = key.replace("conv_shortcut.bias", "conv_shortcut_layer.conv.bias")
495497

496498
# Handle q, k, v, proj_out in AttnBlock
497499
if "q.weight" in key: key = key.replace("q.weight", "q.kernel")
@@ -503,6 +505,24 @@ def rename_for_ltx2_audio_vae(key):
503505
if key.endswith(".weight") and "conv" in key:
504506
key = key.replace(".weight", ".kernel")
505507

508+
# Inject .conv for CausalConvs
509+
# Layers: conv1, conv2, conv_in, conv_out
510+
# These become conv1.conv.kernel etc.
511+
causal_layers = ["conv1", "conv2", "conv_in", "conv_out"]
512+
for layer in causal_layers:
513+
if f"{layer}.kernel" in key:
514+
key = key.replace(f"{layer}.kernel", f"{layer}.conv.kernel")
515+
if f"{layer}.bias" in key:
516+
key = key.replace(f"{layer}.bias", f"{layer}.conv.bias")
517+
518+
# Special handling for upsample.conv (wrapped) vs downsample.conv (not wrapped)
519+
# upsamplers.0.conv -> upsample.conv.conv
520+
# We do this BEFORE the loop renames upsamplers.0 -> upsample
521+
if "upsamplers" in key and ".conv.kernel" in key:
522+
key = key.replace(".conv.kernel", ".conv.conv.kernel")
523+
if "upsamplers" in key and ".conv.bias" in key:
524+
key = key.replace(".conv.bias", ".conv.conv.bias")
525+
506526
return key
507527

508528

@@ -527,9 +547,6 @@ def load_audio_vae_weights(
527547
key = rename_for_ltx2_audio_vae(pt_key)
528548

529549
# Determine if we need to transpose (Conv weights: PyTorch OIHW -> Flax HWIO)
530-
# PyTorch Conv2d: (Out, In, H, W) -> Flax: (H, W, In, Out)
531-
# However, for 1x1 convs (like q, k, v), it might be (Out, In, 1, 1) -> (1, 1, In, Out)
532-
533550
should_transpose = False
534551
if key.endswith(".kernel"):
535552
if tensor.ndim == 4:
@@ -540,7 +557,6 @@ def load_audio_vae_weights(
540557

541558
# Handle special keys: latents_mean, latents_std
542559
if "latents_mean" in key:
543-
# PyTorch: [C], Flax: [C] (Buffer)
544560
pass
545561
if "latents_std" in key:
546562
pass
@@ -556,38 +572,20 @@ def load_audio_vae_weights(
556572

557573
flax_key = tuple(flax_key_parts)
558574

559-
# Handle resnet nesting (down_blocks.0.resnets.0...)
560-
# LTX-2 Audio VAE structure in Flax might be slightly different if not using List
561-
# But we used nnx.List in the implementation, so it should match mostly.
562-
# Let's check against random_flax_state_dict if possible or rely on structure.
563-
564-
# Special handling for "mid_block" which in PT might be "mid_block.resnets.0"
565-
# but in our Flax implementation is "mid_block1", "mid_block2"
566-
567575
if "mid_block" in pt_key:
568-
# PT: mid_block.resnets.0 -> Flax: mid_block1
569-
# PT: mid_block.attentions.0 -> Flax: mid_attn
570-
# PT: mid_block.resnets.1 -> Flax: mid_block2
571-
572-
new_flax_key_parts = list(flax_key)
573-
if "resnets" in new_flax_key_parts:
574-
idx = new_flax_key_parts[new_flax_key_parts.index("resnets") + 1]
575-
if idx == 0:
576-
# Replace 'mid_block', 'resnets', 0 with 'mid_block1'
577-
# Warning: This is a bit fragile.
578-
pass
579-
580-
# Actually, let's map explicitly based on known structure
581576
if "mid_block.resnets.0" in pt_key:
582-
# mid_block.resnets.0.conv1.weight -> mid_block1.conv1.kernel
583-
key = key.replace("mid_block.resnets.0", "mid_block1")
577+
flax_key_str = ".".join([str(x) for x in flax_key])
578+
flax_key_str = flax_key_str.replace("mid_block.resnets.0", "mid_block1")
584579
elif "mid_block.resnets.1" in pt_key:
585-
key = key.replace("mid_block.resnets.1", "mid_block2")
580+
flax_key_str = ".".join([str(x) for x in flax_key])
581+
flax_key_str = flax_key_str.replace("mid_block.resnets.1", "mid_block2")
586582
elif "mid_block.attentions.0" in pt_key:
587-
key = key.replace("mid_block.attentions.0", "mid_attn")
583+
flax_key_str = ".".join([str(x) for x in flax_key])
584+
flax_key_str = flax_key_str.replace("mid_block.attentions.0", "mid_attn")
585+
else:
586+
flax_key_str = ".".join([str(x) for x in flax_key])
588587

589-
# Re-split after mid_block renaming
590-
parts = key.split(".")
588+
parts = flax_key_str.split(".")
591589
flax_key_parts = []
592590
for part in parts:
593591
if part.isdigit():
@@ -596,26 +594,19 @@ def load_audio_vae_weights(
596594
flax_key_parts.append(part)
597595
flax_key = tuple(flax_key_parts)
598596

599-
# Handle down_blocks / up_blocks
600-
# PT: down_blocks.0.resnets.0 -> Flax: down_stages.0.blocks.0
601-
# PT: down_blocks.0.attentions.0 -> Flax: down_stages.0.attentions.0
602-
# PT: down_blocks.0.downsamplers.0 -> Flax: down_stages.0.downsample
603-
604597
if "down_blocks" in key:
605-
# down_blocks.0.resnets.0 -> down_stages.0.blocks.0
606-
if "resnets" in key:
607-
key = key.replace("down_blocks", "down_stages")
608-
key = key.replace("resnets", "blocks")
609-
elif "attentions" in key:
610-
key = key.replace("down_blocks", "down_stages")
611-
key = key.replace("attentions", "attns")
612-
elif "downsamplers" in key:
613-
key = key.replace("down_blocks", "down_stages")
614-
# downsamplers.0 -> downsample (since we have one downsample per stage)
615-
key = key.replace("downsamplers.0", "downsample")
598+
key_str = ".".join([str(x) for x in flax_key])
599+
if "resnets" in key_str:
600+
key_str = key_str.replace("down_blocks", "down_stages")
601+
key_str = key_str.replace("resnets", "blocks")
602+
elif "attentions" in key_str:
603+
key_str = key_str.replace("down_blocks", "down_stages")
604+
key_str = key_str.replace("attentions", "attns")
605+
elif "downsamplers" in key_str:
606+
key_str = key_str.replace("down_blocks", "down_stages")
607+
key_str = key_str.replace("downsamplers.0", "downsample")
616608

617-
# Re-split
618-
parts = key.split(".")
609+
parts = key_str.split(".")
619610
flax_key_parts = []
620611
for part in parts:
621612
if part.isdigit():
@@ -625,29 +616,18 @@ def load_audio_vae_weights(
625616
flax_key = tuple(flax_key_parts)
626617

627618
if "up_blocks" in key:
628-
# up_blocks.0.resnets.0 -> up_stages.0.blocks.0
629-
# Note: PT up_blocks are usually reversed compared to simple iteration, but
630-
# in Diffusers they correspond to levels.
631-
# Flax implementation: `up_stages` list iterates reversed(range(len(ch_mult)))
632-
# so up_stages[0] corresponds to the deepest resolution?
633-
# LTX-2 Audio VAE implementation:
634-
# for level in reversed(range(len(self.ch_mult))): ... self.up_stages.append(...)
635-
# So up_stages[0] is the first upsample stage (lowest res -> higher res).
636-
# Diffusers `up_blocks` usually go 0, 1, 2...
637-
# So it should be a direct mapping if existing logic holds.
638-
639-
if "resnets" in key:
640-
key = key.replace("up_blocks", "up_stages")
641-
key = key.replace("resnets", "blocks")
642-
elif "attentions" in key:
643-
key = key.replace("up_blocks", "up_stages")
644-
key = key.replace("attentions", "attns")
645-
elif "upsamplers" in key:
646-
key = key.replace("up_blocks", "up_stages")
647-
key = key.replace("upsamplers.0", "upsample")
619+
key_str = ".".join([str(x) for x in flax_key])
620+
if "resnets" in key_str:
621+
key_str = key_str.replace("up_blocks", "up_stages")
622+
key_str = key_str.replace("resnets", "blocks")
623+
elif "attentions" in key_str:
624+
key_str = key_str.replace("up_blocks", "up_stages")
625+
key_str = key_str.replace("attentions", "attns")
626+
elif "upsamplers" in key_str:
627+
key_str = key_str.replace("up_blocks", "up_stages")
628+
key_str = key_str.replace("upsamplers.0", "upsample")
648629

649-
# Re-split
650-
parts = key.split(".")
630+
parts = key_str.split(".")
651631
flax_key_parts = []
652632
for part in parts:
653633
if part.isdigit():
@@ -662,7 +642,12 @@ def load_audio_vae_weights(
662642
filtered_eval_shapes = {}
663643
for k, v in flattened_eval.items():
664644
k_str = [str(x) for x in k]
665-
if "dropout" in k_str or "rngs" in k_str:
645+
is_stat = False
646+
for ks in k_str:
647+
if "dropout" in ks or "rngs" in ks:
648+
is_stat = True
649+
break
650+
if is_stat:
666651
continue
667652
filtered_eval_shapes[k] = v
668653

0 commit comments

Comments
 (0)