@@ -485,3 +485,186 @@ def load_connector_weights(
485485 jax .clear_caches ()
486486 validate_flax_state_dict (eval_shapes , flax_state_dict )
487487 return unflatten_dict (flax_state_dict )
488+
def rename_for_ltx2_audio_vae(key: str) -> str:
    """Rename one PyTorch audio-VAE state-dict key to the Flax naming scheme.

    Handles the standard VAE renames (``resblocks`` -> ``resnets``,
    ``ups`` -> ``upsamplers``), the attention projections
    (``q/k/v/proj_out.weight`` -> ``*.kernel``), the shortcut conv
    (``conv_shortcut.{weight,bias}`` -> ``conv_shortcut_layer.{kernel,bias}``)
    and the generic conv ``.weight`` -> ``.kernel`` suffix swap.

    All matching is done on whole dot-separated path segments, so keys that
    already contain ``upsamplers`` are left intact (a raw substring replace
    of ``"ups"`` would corrupt them) and a leaf like ``norm_k.weight`` is not
    mistaken for the attention ``k`` projection.

    Args:
        key: Dotted PyTorch state-dict key.

    Returns:
        The renamed dotted key; keys that need no renaming (e.g. norm
        weights, biases, ``latents_mean``/``latents_std``) pass through
        unchanged.
    """
    segment_renames = {"resblocks": "resnets", "ups": "upsamplers"}
    segs = [segment_renames.get(seg, seg) for seg in key.split(".")]

    if len(segs) >= 2:
        parent, leaf = segs[-2], segs[-1]
        if leaf == "weight":
            if parent == "conv_shortcut":
                segs[-2:] = ["conv_shortcut_layer", "kernel"]
            elif parent in ("q", "k", "v", "proj_out"):
                # 1x1 attention projections in the AttnBlock.
                segs[-1] = "kernel"
            elif any("conv" in seg for seg in segs):
                # Generic convolution weight -> Flax kernel.
                segs[-1] = "kernel"
        elif leaf == "bias" and parent == "conv_shortcut":
            segs[-2] = "conv_shortcut_layer"

    return ".".join(segs)
507+
508+
def _flax_key_tuple(key: str) -> tuple:
    """Turn a dotted key string into a Flax dict key tuple.

    Purely numeric segments become ints (list indices), e.g.
    ``"down_stages.0.blocks.1.conv1.kernel"`` ->
    ``("down_stages", 0, "blocks", 1, "conv1", "kernel")``.
    """
    return tuple(int(part) if part.isdigit() else part for part in key.split("."))


def load_audio_vae_weights(
    pretrained_model_name_or_path: str,
    eval_shapes: dict,
    device: str,
    hf_download: bool = True,
    subfolder: str = "audio_vae",
):
    """Load LTX-2 audio-VAE PyTorch weights and remap them to the Flax layout.

    Args:
        pretrained_model_name_or_path: Repo id or local path handed to
            ``load_sharded_checkpoint``.
        eval_shapes: Expected Flax parameter tree, used for validation only
            (``dropout``/``rngs`` entries are ignored — they never appear in a
            checkpoint).
        device: Device string forwarded to ``load_sharded_checkpoint``.
        hf_download: Unused here; kept for signature compatibility with the
            other ``load_*_weights`` helpers in this module.
        subfolder: Checkpoint subfolder holding the audio-VAE shards.

    Returns:
        Nested (unflattened) Flax state dict with tensors placed on CPU.
    """
    tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device)
    cpu = jax.local_devices(backend="cpu")[0]
    flax_state_dict = {}

    for pt_key, tensor in tensors.items():
        key = rename_for_ltx2_audio_vae(pt_key)

        # PyTorch conv kernels are (Out, In, H, W); Flax expects (H, W, In, Out).
        # 1x1 attention projections (q/k/v/proj_out) are (Out, In, 1, 1) and use
        # the same permutation. 1-D buffers such as latents_mean/latents_std
        # pass through untouched.
        if key.endswith(".kernel") and tensor.ndim == 4:
            tensor = tensor.transpose(2, 3, 1, 0)

        # The Flax model flattens the middle block:
        #   mid_block.resnets.0    -> mid_block1
        #   mid_block.attentions.0 -> mid_attn
        #   mid_block.resnets.1    -> mid_block2
        # Match on the renamed `key` (not `pt_key`) so checkpoints that spell
        # the resnets "resblocks" are remapped correctly too.
        if "mid_block.resnets.0" in key:
            key = key.replace("mid_block.resnets.0", "mid_block1")
        elif "mid_block.resnets.1" in key:
            key = key.replace("mid_block.resnets.1", "mid_block2")
        elif "mid_block.attentions.0" in key:
            key = key.replace("mid_block.attentions.0", "mid_attn")

        # Encoder stages: down_blocks.i.{resnets, attentions, downsamplers.0}
        #              -> down_stages.i.{blocks, attns, downsample}
        if "down_blocks" in key:
            if "resnets" in key:
                key = key.replace("down_blocks", "down_stages").replace("resnets", "blocks")
            elif "attentions" in key:
                key = key.replace("down_blocks", "down_stages").replace("attentions", "attns")
            elif "downsamplers" in key:
                # Each stage has at most one downsampler, hence the singular name.
                key = key.replace("down_blocks", "down_stages").replace("downsamplers.0", "downsample")

        # Decoder stages mirror the encoder. Diffusers' up_blocks are ordered
        # lowest-resolution first, which matches up_stages index-for-index
        # (the Flax model appends stages while iterating reversed(ch_mult)).
        if "up_blocks" in key:
            if "resnets" in key:
                key = key.replace("up_blocks", "up_stages").replace("resnets", "blocks")
            elif "attentions" in key:
                key = key.replace("up_blocks", "up_stages").replace("attentions", "attns")
            elif "upsamplers" in key:
                key = key.replace("up_blocks", "up_stages").replace("upsamplers.0", "upsample")

        flax_state_dict[_flax_key_tuple(key)] = jax.device_put(tensor, device=cpu)

    # Validate against the expected shapes, skipping non-parameter state
    # (dropout state, rng keys) that a checkpoint never contains.
    filtered_eval_shapes = {
        k: v
        for k, v in flatten_dict(eval_shapes).items()
        if not any(str(part) in ("dropout", "rngs") for part in k)
    }
    validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flax_state_dict)
    return unflatten_dict(flax_state_dict)
0 commit comments