Skip to content

Commit db4acec

Browse files
committed
debug_audio_vae
1 parent a794925 commit db4acec

1 file changed

Lines changed: 31 additions & 102 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 31 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -487,42 +487,37 @@ def load_connector_weights(
487487
return unflatten_dict(flax_state_dict)
488488

489489
def rename_for_ltx2_audio_vae(key):
490-
# Standard VAE renaming (resblocks -> resnets, ups -> upsamplers)
491-
key = key.replace("resblocks", "resnets")
492-
key = key.replace("ups", "upsamplers")
490+
# LTX-2 Audio VAE specific renaming
491+
492+
# 1. Common renames
493+
if key.endswith(".weight"):
494+
key = key.replace(".weight", ".kernel")
495+
496+
# 2. Structure renames
497+
# mid.block_1 -> mid_block1
498+
key = key.replace("mid.block_1", "mid_block1")
499+
key = key.replace("mid.block_2", "mid_block2")
500+
key = key.replace("mid.attn_1", "mid_attn") # Assumption, but safe
493501

494-
# conv_shortcut -> conv_shortcut_layer.conv (Causal)
495-
key = key.replace("conv_shortcut.weight", "conv_shortcut_layer.conv.kernel")
496-
key = key.replace("conv_shortcut.bias", "conv_shortcut_layer.conv.bias")
502+
# up.0 -> up_stages.0
503+
key = key.replace("up.", "up_stages.")
504+
# down.0 -> down_stages.0
505+
key = key.replace("down.", "down_stages.")
497506

498-
# Handle q, k, v, proj_out in AttnBlock
499-
if "q.weight" in key: key = key.replace("q.weight", "q.kernel")
500-
if "k.weight" in key: key = key.replace("k.weight", "k.kernel")
501-
if "v.weight" in key: key = key.replace("v.weight", "v.kernel")
502-
if "proj_out.weight" in key: key = key.replace("proj_out.weight", "proj_out.kernel")
507+
# block.0 -> blocks.0
508+
key = key.replace("block.", "blocks.")
503509

504-
# Handle conv.weight -> conv.kernel
505-
if key.endswith(".weight") and "conv" in key:
506-
key = key.replace(".weight", ".kernel")
510+
# nin_shortcut -> conv_shortcut_layer
511+
key = key.replace("nin_shortcut", "conv_shortcut_layer")
512+
513+
# In case upsample/downsample keys are just 'upsample' / 'downsample'
514+
# Check for CausalConv wrapping in upsample
515+
# If PT is 'upsample.conv.kernel' but Flax needs 'upsample.conv.conv.kernel'
516+
if "upsample.conv.kernel" in key:
517+
key = key.replace("upsample.conv.kernel", "upsample.conv.conv.kernel")
518+
if "upsample.conv.bias" in key:
519+
key = key.replace("upsample.conv.bias", "upsample.conv.conv.bias")
507520

508-
# Inject .conv for CausalConvs
509-
# Layers: conv1, conv2, conv_in, conv_out
510-
# These become conv1.conv.kernel etc.
511-
causal_layers = ["conv1", "conv2", "conv_in", "conv_out"]
512-
for layer in causal_layers:
513-
if f"{layer}.kernel" in key:
514-
key = key.replace(f"{layer}.kernel", f"{layer}.conv.kernel")
515-
if f"{layer}.bias" in key:
516-
key = key.replace(f"{layer}.bias", f"{layer}.conv.bias")
517-
518-
# Special handling for upsample.conv (wrapped) vs downsample.conv (not wrapped)
519-
# upsamplers.0.conv -> upsample.conv.conv
520-
# We do this BEFORE the loop renames upsamplers.0 -> upsample
521-
if "upsamplers" in key and ".conv.kernel" in key:
522-
key = key.replace(".conv.kernel", ".conv.conv.kernel")
523-
if "upsamplers" in key and ".conv.bias" in key:
524-
key = key.replace(".conv.bias", ".conv.conv.bias")
525-
526521
return key
527522

528523

@@ -547,6 +542,10 @@ def load_audio_vae_weights(
547542
key = rename_for_ltx2_audio_vae(pt_key)
548543

549544
# Determine if we need to transpose (Conv weights: OHWI -> HWIO)
545+
# Note: 1x1 convs might also be 4D.
546+
# Standard Flax Conv: (H, W, I, O)
547+
# Standard PyTorch Conv: (O, I, H, W)
548+
550549
should_transpose = False
551550
if key.endswith(".kernel"):
552551
if tensor.ndim == 4:
@@ -555,12 +554,6 @@ def load_audio_vae_weights(
555554
if should_transpose:
556555
tensor = tensor.transpose(2, 3, 1, 0)
557556

558-
# Handle special keys: latents_mean, latents_std
559-
if "latents_mean" in key:
560-
pass
561-
if "latents_std" in key:
562-
pass
563-
564557
# Convert key to tuple
565558
parts = key.split(".")
566559
flax_key_parts = []
@@ -571,70 +564,6 @@ def load_audio_vae_weights(
571564
flax_key_parts.append(part)
572565

573566
flax_key = tuple(flax_key_parts)
574-
575-
if "mid_block" in pt_key:
576-
if "mid_block.resnets.0" in pt_key:
577-
flax_key_str = ".".join([str(x) for x in flax_key])
578-
flax_key_str = flax_key_str.replace("mid_block.resnets.0", "mid_block1")
579-
elif "mid_block.resnets.1" in pt_key:
580-
flax_key_str = ".".join([str(x) for x in flax_key])
581-
flax_key_str = flax_key_str.replace("mid_block.resnets.1", "mid_block2")
582-
elif "mid_block.attentions.0" in pt_key:
583-
flax_key_str = ".".join([str(x) for x in flax_key])
584-
flax_key_str = flax_key_str.replace("mid_block.attentions.0", "mid_attn")
585-
else:
586-
flax_key_str = ".".join([str(x) for x in flax_key])
587-
588-
parts = flax_key_str.split(".")
589-
flax_key_parts = []
590-
for part in parts:
591-
if part.isdigit():
592-
flax_key_parts.append(int(part))
593-
else:
594-
flax_key_parts.append(part)
595-
flax_key = tuple(flax_key_parts)
596-
597-
if "down_blocks" in key:
598-
key_str = ".".join([str(x) for x in flax_key])
599-
if "resnets" in key_str:
600-
key_str = key_str.replace("down_blocks", "down_stages")
601-
key_str = key_str.replace("resnets", "blocks")
602-
elif "attentions" in key_str:
603-
key_str = key_str.replace("down_blocks", "down_stages")
604-
key_str = key_str.replace("attentions", "attns")
605-
elif "downsamplers" in key_str:
606-
key_str = key_str.replace("down_blocks", "down_stages")
607-
key_str = key_str.replace("downsamplers.0", "downsample")
608-
609-
parts = key_str.split(".")
610-
flax_key_parts = []
611-
for part in parts:
612-
if part.isdigit():
613-
flax_key_parts.append(int(part))
614-
else:
615-
flax_key_parts.append(part)
616-
flax_key = tuple(flax_key_parts)
617-
618-
if "up_blocks" in key:
619-
key_str = ".".join([str(x) for x in flax_key])
620-
if "resnets" in key_str:
621-
key_str = key_str.replace("up_blocks", "up_stages")
622-
key_str = key_str.replace("resnets", "blocks")
623-
elif "attentions" in key_str:
624-
key_str = key_str.replace("up_blocks", "up_stages")
625-
key_str = key_str.replace("attentions", "attns")
626-
elif "upsamplers" in key_str:
627-
key_str = key_str.replace("up_blocks", "up_stages")
628-
key_str = key_str.replace("upsamplers.0", "upsample")
629-
630-
parts = key_str.split(".")
631-
flax_key_parts = []
632-
for part in parts:
633-
if part.isdigit():
634-
flax_key_parts.append(int(part))
635-
else:
636-
flax_key_parts.append(part)
637-
flax_key = tuple(flax_key_parts)
638567

639568
flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu)
640569

0 commit comments

Comments
 (0)