@@ -490,8 +490,10 @@ def rename_for_ltx2_audio_vae(key):
490490 # Standard VAE renaming (resblocks -> resnets, ups -> upsamplers)
491491 key = key .replace ("resblocks" , "resnets" )
492492 key = key .replace ("ups" , "upsamplers" )
493- key = key .replace ("conv_shortcut.weight" , "conv_shortcut_layer.kernel" )
494- key = key .replace ("conv_shortcut.bias" , "conv_shortcut_layer.bias" )
493+
494+ # conv_shortcut -> conv_shortcut_layer.conv (Causal)
495+ key = key .replace ("conv_shortcut.weight" , "conv_shortcut_layer.conv.kernel" )
496+ key = key .replace ("conv_shortcut.bias" , "conv_shortcut_layer.conv.bias" )
495497
496498 # Handle q, k, v, proj_out in AttnBlock
497499 if "q.weight" in key : key = key .replace ("q.weight" , "q.kernel" )
@@ -503,6 +505,24 @@ def rename_for_ltx2_audio_vae(key):
503505 if key .endswith (".weight" ) and "conv" in key :
504506 key = key .replace (".weight" , ".kernel" )
505507
508+ # Inject .conv for CausalConvs
509+ # Layers: conv1, conv2, conv_in, conv_out
510+ # These become conv1.conv.kernel etc.
511+ causal_layers = ["conv1" , "conv2" , "conv_in" , "conv_out" ]
512+ for layer in causal_layers :
513+ if f"{ layer } .kernel" in key :
514+ key = key .replace (f"{ layer } .kernel" , f"{ layer } .conv.kernel" )
515+ if f"{ layer } .bias" in key :
516+ key = key .replace (f"{ layer } .bias" , f"{ layer } .conv.bias" )
517+
518+ # Special handling for upsample.conv (wrapped) vs downsample.conv (not wrapped)
519+ # upsamplers.0.conv -> upsample.conv.conv
520+ # We do this BEFORE the loop renames upsamplers.0 -> upsample
521+ if "upsamplers" in key and ".conv.kernel" in key :
522+ key = key .replace (".conv.kernel" , ".conv.conv.kernel" )
523+ if "upsamplers" in key and ".conv.bias" in key :
524+ key = key .replace (".conv.bias" , ".conv.conv.bias" )
525+
506526 return key
507527
508528
@@ -527,9 +547,6 @@ def load_audio_vae_weights(
527547 key = rename_for_ltx2_audio_vae (pt_key )
528548
529549 # Determine if we need to transpose (Conv weights: OHWI -> HWIO)
530- # PyTorch Conv2d: (Out, In, H, W) -> Flax: (H, W, In, Out)
531- # However, for 1x1 convs (like q, k, v), it might be (Out, In, 1, 1) -> (1, 1, In, Out)
532-
533550 should_transpose = False
534551 if key .endswith (".kernel" ):
535552 if tensor .ndim == 4 :
@@ -540,7 +557,6 @@ def load_audio_vae_weights(
540557
541558 # Handle special keys: latents_mean, latents_std
542559 if "latents_mean" in key :
543- # PyTorch: [C], Flax: [C] (Buffer)
544560 pass
545561 if "latents_std" in key :
546562 pass
@@ -556,38 +572,20 @@ def load_audio_vae_weights(
556572
557573 flax_key = tuple (flax_key_parts )
558574
559- # Handle resnet nesting (down_blocks.0.resnets.0...)
560- # LTX-2 Audio VAE structure in Flax might be slightly different if not using List
561- # But we used nnx.List in the implementation, so it should match mostly.
562- # Let's check against random_flax_state_dict if possible or rely on structure.
563-
564- # Special handling for "mid_block" which in PT might be "mid_block.resnets.0"
565- # but in our Flax implementation is "mid_block1", "mid_block2"
566-
567575 if "mid_block" in pt_key :
568- # PT: mid_block.resnets.0 -> Flax: mid_block1
569- # PT: mid_block.attentions.0 -> Flax: mid_attn
570- # PT: mid_block.resnets.1 -> Flax: mid_block2
571-
572- new_flax_key_parts = list (flax_key )
573- if "resnets" in new_flax_key_parts :
574- idx = new_flax_key_parts [new_flax_key_parts .index ("resnets" ) + 1 ]
575- if idx == 0 :
576- # Replace 'mid_block', 'resnets', 0 with 'mid_block1'
577- # Warning: This is a bit fragile.
578- pass
579-
580- # Actually, let's map explicitly based on known structure
581576 if "mid_block.resnets.0" in pt_key :
582- # mid_block.resnets.0.conv1.weight -> mid_block1.conv1.kernel
583- key = key .replace ("mid_block.resnets.0" , "mid_block1" )
577+ flax_key_str = "." . join ([ str ( x ) for x in flax_key ])
578+ flax_key_str = flax_key_str .replace ("mid_block.resnets.0" , "mid_block1" )
584579 elif "mid_block.resnets.1" in pt_key :
585- key = key .replace ("mid_block.resnets.1" , "mid_block2" )
580+ flax_key_str = "." .join ([str (x ) for x in flax_key ])
581+ flax_key_str = flax_key_str .replace ("mid_block.resnets.1" , "mid_block2" )
586582 elif "mid_block.attentions.0" in pt_key :
587- key = key .replace ("mid_block.attentions.0" , "mid_attn" )
583+ flax_key_str = "." .join ([str (x ) for x in flax_key ])
584+ flax_key_str = flax_key_str .replace ("mid_block.attentions.0" , "mid_attn" )
585+ else :
586+ flax_key_str = "." .join ([str (x ) for x in flax_key ])
588587
589- # Re-split after mid_block renaming
590- parts = key .split ("." )
588+ parts = flax_key_str .split ("." )
591589 flax_key_parts = []
592590 for part in parts :
593591 if part .isdigit ():
@@ -596,26 +594,19 @@ def load_audio_vae_weights(
596594 flax_key_parts .append (part )
597595 flax_key = tuple (flax_key_parts )
598596
599- # Handle down_blocks / up_blocks
600- # PT: down_blocks.0.resnets.0 -> Flax: down_stages.0.blocks.0
601- # PT: down_blocks.0.attentions.0 -> Flax: down_stages.0.attentions.0
602- # PT: down_blocks.0.downsamplers.0 -> Flax: down_stages.0.downsample
603-
604597 if "down_blocks" in key :
605- # down_blocks.0.resnets.0 -> down_stages.0.blocks.0
606- if "resnets" in key :
607- key = key .replace ("down_blocks" , "down_stages" )
608- key = key .replace ("resnets" , "blocks" )
609- elif "attentions" in key :
610- key = key .replace ("down_blocks" , "down_stages" )
611- key = key .replace ("attentions" , "attns" )
612- elif "downsamplers" in key :
613- key = key .replace ("down_blocks" , "down_stages" )
614- # downsamplers.0 -> downsample (since we have one downsample per stage)
615- key = key .replace ("downsamplers.0" , "downsample" )
598+ key_str = "." .join ([str (x ) for x in flax_key ])
599+ if "resnets" in key_str :
600+ key_str = key_str .replace ("down_blocks" , "down_stages" )
601+ key_str = key_str .replace ("resnets" , "blocks" )
602+ elif "attentions" in key_str :
603+ key_str = key_str .replace ("down_blocks" , "down_stages" )
604+ key_str = key_str .replace ("attentions" , "attns" )
605+ elif "downsamplers" in key_str :
606+ key_str = key_str .replace ("down_blocks" , "down_stages" )
607+ key_str = key_str .replace ("downsamplers.0" , "downsample" )
616608
617- # Re-split
618- parts = key .split ("." )
609+ parts = key_str .split ("." )
619610 flax_key_parts = []
620611 for part in parts :
621612 if part .isdigit ():
@@ -625,29 +616,18 @@ def load_audio_vae_weights(
625616 flax_key = tuple (flax_key_parts )
626617
627618 if "up_blocks" in key :
628- # up_blocks.0.resnets.0 -> up_stages.0.blocks.0
629- # Note: PT up_blocks are usually reversed compared to simple iteration, but
630- # in Diffusers they correspond to levels.
631- # Flax implementation: `up_stages` list iterates reversed(range(len(ch_mult)))
632- # so up_stages[0] corresponds to the deepest resolution?
633- # LTX-2 Audio VAE implementation:
634- # for level in reversed(range(len(self.ch_mult))): ... self.up_stages.append(...)
635- # So up_stages[0] is the first upsample stage (lowest res -> higher res).
636- # Diffusers `up_blocks` usually go 0, 1, 2...
637- # So it should be a direct mapping if existing logic holds.
638-
639- if "resnets" in key :
640- key = key .replace ("up_blocks" , "up_stages" )
641- key = key .replace ("resnets" , "blocks" )
642- elif "attentions" in key :
643- key = key .replace ("up_blocks" , "up_stages" )
644- key = key .replace ("attentions" , "attns" )
645- elif "upsamplers" in key :
646- key = key .replace ("up_blocks" , "up_stages" )
647- key = key .replace ("upsamplers.0" , "upsample" )
619+ key_str = "." .join ([str (x ) for x in flax_key ])
620+ if "resnets" in key_str :
621+ key_str = key_str .replace ("up_blocks" , "up_stages" )
622+ key_str = key_str .replace ("resnets" , "blocks" )
623+ elif "attentions" in key_str :
624+ key_str = key_str .replace ("up_blocks" , "up_stages" )
625+ key_str = key_str .replace ("attentions" , "attns" )
626+ elif "upsamplers" in key_str :
627+ key_str = key_str .replace ("up_blocks" , "up_stages" )
628+ key_str = key_str .replace ("upsamplers.0" , "upsample" )
648629
649- # Re-split
650- parts = key .split ("." )
630+ parts = key_str .split ("." )
651631 flax_key_parts = []
652632 for part in parts :
653633 if part .isdigit ():
@@ -662,7 +642,12 @@ def load_audio_vae_weights(
662642 filtered_eval_shapes = {}
663643 for k , v in flattened_eval .items ():
664644 k_str = [str (x ) for x in k ]
665- if "dropout" in k_str or "rngs" in k_str :
645+ is_stat = False
646+ for ks in k_str :
647+ if "dropout" in ks or "rngs" in ks :
648+ is_stat = True
649+ break
650+ if is_stat :
666651 continue
667652 filtered_eval_shapes [k ] = v
668653
0 commit comments