Skip to content

Commit c109dbd

Browse files
committed
Audio VAE weights
1 parent 753ab0f commit c109dbd

2 files changed

Lines changed: 226 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,186 @@ def load_connector_weights(
485485
jax.clear_caches()
486486
validate_flax_state_dict(eval_shapes, flax_state_dict)
487487
return unflatten_dict(flax_state_dict)
488+
489+
def rename_for_ltx2_audio_vae(key):
  """Translate a PyTorch Audio-VAE checkpoint key into the Flax naming scheme.

  Handles:
    * container renames (``resblocks`` -> ``resnets``, ``ups`` -> ``upsamplers``),
      applied per dotted path component so that tokens which merely *contain*
      the substring are untouched — a raw ``str.replace("ups", ...)`` would
      corrupt ``"upsamplers"`` into ``"upsamplersamplers"`` (breaking the
      ``upsamplers.0`` handling downstream) and mangle tokens like ``"groups"``,
    * attention projections (``q``/``k``/``v``/``proj_out``): ``.weight`` -> ``.kernel``,
    * conv shortcut layers and any remaining conv ``.weight`` -> ``.kernel``.

  Args:
    key: dot-separated PyTorch state-dict key.

  Returns:
    The renamed key as a dot-separated string.
  """
  # Rename whole path components only (bug fix: substring replace corrupted
  # keys already containing "upsamplers").
  component_map = {"resblocks": "resnets", "ups": "upsamplers"}
  key = ".".join(component_map.get(part, part) for part in key.split("."))

  key = key.replace("conv_shortcut.weight", "conv_shortcut_layer.kernel")
  key = key.replace("conv_shortcut.bias", "conv_shortcut_layer.bias")

  # AttnBlock projections: PyTorch stores the parameter as ".weight";
  # the Flax module names it ".kernel". str.replace is a no-op on miss,
  # so no guarding `if` is needed.
  for proj in ("q", "k", "v", "proj_out"):
    key = key.replace(f"{proj}.weight", f"{proj}.kernel")

  # Any remaining conv weight also becomes a kernel.
  if key.endswith(".weight") and "conv" in key:
    key = key.replace(".weight", ".kernel")

  return key
507+
508+
509+
def load_audio_vae_weights(
    pretrained_model_name_or_path: str,
    eval_shapes: dict,
    device: str,
    hf_download: bool = True,
    subfolder: str = "audio_vae",
):
  """Load the LTX-2 Audio VAE PyTorch checkpoint and convert it to a Flax state dict.

  Args:
    pretrained_model_name_or_path: HuggingFace repo id or local checkpoint path.
    eval_shapes: nested dict of expected Flax parameter shapes (pure dict of the
      model state), used only to validate the converted weights.
    device: device string forwarded to ``load_sharded_checkpoint``.
    hf_download: kept for signature parity with the other loaders in this
      module; downloading is handled inside ``load_sharded_checkpoint``.
    subfolder: checkpoint subfolder containing the audio VAE weights.

  Returns:
    Nested (unflattened) dict of converted weights, placed on the host CPU.
  """
  tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device)
  cpu = jax.local_devices(backend="cpu")[0]
  flax_state_dict = {}

  def _key_to_tuple(dotted_key):
    # "a.0.b" -> ("a", 0, "b"); numeric components index list entries.
    return tuple(int(part) if part.isdigit() else part for part in dotted_key.split("."))

  for pt_key, tensor in tensors.items():
    key = rename_for_ltx2_audio_vae(pt_key)

    # PyTorch Conv2d kernels are (Out, In, H, W); Flax expects (H, W, In, Out).
    # 1x1 convs (attention q/k/v/proj_out: (Out, In, 1, 1) -> (1, 1, In, Out))
    # follow the same rule.
    # NOTE(review): assumes every 4-D ".kernel" tensor is a conv weight —
    # confirm if non-2D convs are ever added to this VAE.
    if key.endswith(".kernel") and tensor.ndim == 4:
      tensor = tensor.transpose(2, 3, 1, 0)

    # latents_mean / latents_std buffers pass through with no reshaping.

    # Mid block: PyTorch nests resnets/attentions in lists, while the Flax
    # model exposes them as named attributes (mid_block1, mid_attn, mid_block2).
    # NOTE(review): matched against the *original* key, so this assumes the
    # checkpoint already uses Diffusers-style "mid_block.resnets.N" names —
    # verify against the published checkpoint layout.
    if "mid_block.resnets.0" in pt_key:
      key = key.replace("mid_block.resnets.0", "mid_block1")
    elif "mid_block.resnets.1" in pt_key:
      key = key.replace("mid_block.resnets.1", "mid_block2")
    elif "mid_block.attentions.0" in pt_key:
      key = key.replace("mid_block.attentions.0", "mid_attn")

    # Encoder stages:
    #   down_blocks.N.resnets.M      -> down_stages.N.blocks.M
    #   down_blocks.N.attentions.M   -> down_stages.N.attns.M
    #   down_blocks.N.downsamplers.0 -> down_stages.N.downsample
    if "down_blocks" in key:
      if "resnets" in key:
        key = key.replace("down_blocks", "down_stages").replace("resnets", "blocks")
      elif "attentions" in key:
        key = key.replace("down_blocks", "down_stages").replace("attentions", "attns")
      elif "downsamplers" in key:
        # One downsample module per stage, so the trailing list index is dropped.
        key = key.replace("down_blocks", "down_stages").replace("downsamplers.0", "downsample")

    # Decoder stages mirror the encoder. The Flax up_stages list is built with
    # `for level in reversed(range(len(ch_mult)))`, so up_stages[0] is the
    # deepest resolution — the same ordering Diffusers uses for up_blocks,
    # hence the direct index mapping.
    if "up_blocks" in key:
      if "resnets" in key:
        key = key.replace("up_blocks", "up_stages").replace("resnets", "blocks")
      elif "attentions" in key:
        key = key.replace("up_blocks", "up_stages").replace("attentions", "attns")
      elif "upsamplers" in key:
        key = key.replace("up_blocks", "up_stages").replace("upsamplers.0", "upsample")

    # The key tuple only needs to be materialized once, after every rename.
    flax_state_dict[_key_to_tuple(key)] = jax.device_put(tensor, device=cpu)

  # Validate against the expected shapes, skipping non-parameter state —
  # dropout / rng collections carry no checkpoint weights.
  flattened_eval = flatten_dict(eval_shapes)
  filtered_eval_shapes = {}
  for shape_key, shape_val in flattened_eval.items():
    parts = [str(x) for x in shape_key]
    if "dropout" in parts or "rngs" in parts:
      continue
    filtered_eval_shapes[shape_key] = shape_val

  validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flax_state_dict)
  return unflatten_dict(flax_state_dict)

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,5 +183,48 @@ def test_load_connector_weights(self):
183183
validate_flax_state_dict(eval_shapes, flatten_dict(loaded_weights))
184184
print("Connector Weights Validated Successfully!")
185185

186+
def test_load_audio_vae_weights(self):
  """Round-trip the LTX-2 audio VAE checkpoint into Flax and validate shapes."""
  from maxdiffusion.models.ltx2.audio_vae import FlaxAutoencoderKLLTX2Audio
  from maxdiffusion.models.ltx2.ltx2_utils import load_audio_vae_weights

  repo_id = "Lightricks/LTX-2"

  # Audio VAE Config from user request
  vae_kwargs = dict(
      base_channels=128,
      ch_mult=(1, 2, 4),
      double_z=True,
      dropout=0.0,
      in_channels=2,
      latent_channels=8,
      mel_bins=64,
      mel_hop_length=160,
      mid_block_add_attention=False,
      norm_type="pixel",
      num_res_blocks=2,
      output_channels=2,
      resolution=256,
      sample_rate=16000,
      rngs=nnx.Rngs(0),
  )

  # Instantiate the randomly-initialized model on CPU so no accelerator
  # memory is consumed before the real weights arrive.
  with jax.default_device(jax.devices("cpu")[0]):
    audio_vae = FlaxAutoencoderKLLTX2Audio(**vae_kwargs)

  eval_shapes = nnx.state(audio_vae).to_pure_dict()

  print("Loading Audio VAE Weights...")
  converted = load_audio_vae_weights(
      pretrained_model_name_or_path=repo_id,
      eval_shapes=eval_shapes,
      device=self.device,
      hf_download=True,
  )

  print("Validating Audio VAE Weights...")
  validate_flax_state_dict(eval_shapes, flatten_dict(converted))
  print("Audio VAE Weights Validated Successfully!")
228+
186229
if __name__ == "__main__":
187230
unittest.main()

0 commit comments

Comments
 (0)