@@ -487,42 +487,37 @@ def load_connector_weights(
487487 return unflatten_dict (flax_state_dict )
488488
def rename_for_ltx2_audio_vae(key):
    """Translate one LTX-2 audio VAE checkpoint key into its Flax-side name.

    The key is treated as a dot-separated path. Structural renames are applied
    to whole path segments only, so a segment that merely ends in ``up``,
    ``down`` or ``block`` (for example ``group`` or ``resblock``) is left alone.

    Renames performed:
      * trailing ``.weight``            -> ``.kernel``
      * ``mid.block_1`` / ``mid.block_2`` -> ``mid_block1`` / ``mid_block2``
      * ``mid.attn_1``                  -> ``mid_attn``
      * ``up.<...>`` / ``down.<...>``   -> ``up_stages.<...>`` / ``down_stages.<...>``
      * ``block.<...>``                 -> ``blocks.<...>``
      * ``nin_shortcut``                -> ``conv_shortcut_layer``
      * ``upsample.conv.{kernel,bias}`` -> ``upsample.conv.conv.{kernel,bias}``

    Args:
      key: Parameter name from the source state dict, e.g.
        ``"decoder.up.0.block.1.conv1.conv.weight"``.

    Returns:
      The renamed key as a dot-separated string.
    """
    # 1. Leaf rename: only the trailing ".weight" is rewritten.
    # NOTE(review): this applies to every layer type, normalization layers
    # included -- confirm the target modules really name that parameter
    # "kernel" rather than "scale".
    weight_suffix = ".weight"
    if key.endswith(weight_suffix):
        key = key[: -len(weight_suffix)] + ".kernel"

    # 2. Structure renames, matched against complete path segments.
    stage_renames = {"up": "up_stages", "down": "down_stages", "block": "blocks"}
    mid_renames = {
        "block_1": "mid_block1",
        "block_2": "mid_block2",
        "attn_1": "mid_attn",  # Assumption carried over from the original mapping.
    }

    segments = key.split(".")
    last = len(segments) - 1
    renamed = []
    i = 0
    while i <= last:
        segment = segments[i]
        # "mid" followed by a known child collapses into a single segment.
        if i < last and segment == "mid" and segments[i + 1] in mid_renames:
            renamed.append(mid_renames[segments[i + 1]])
            i += 2
            continue
        # Container segments are renamed only when something follows them,
        # mirroring the trailing dot the substring patterns required.
        if i < last and segment in stage_renames:
            segment = stage_renames[segment]
        renamed.append(segment)
        i += 1
    key = ".".join(renamed)

    # nin_shortcut -> conv_shortcut_layer (token is distinctive; plain replace).
    key = key.replace("nin_shortcut", "conv_shortcut_layer")

    # The upsample conv is wrapped one level deeper on the Flax side
    # (upsample.conv.conv.*); downsample keys do not contain "upsample"
    # and are therefore untouched.
    key = key.replace("upsample.conv.kernel", "upsample.conv.conv.kernel")
    key = key.replace("upsample.conv.bias", "upsample.conv.conv.bias")

    return key
527522
528523
@@ -547,6 +542,10 @@ def load_audio_vae_weights(
547542 key = rename_for_ltx2_audio_vae (pt_key )
548543
549544 # Determine if we need to transpose (Conv weights: OIHW -> HWIO)
545+ # Note: 1x1 convs might also be 4D.
546+ # Standard Flax Conv: (H, W, I, O)
547+ # Standard PyTorch Conv: (O, I, H, W)
548+
550549 should_transpose = False
551550 if key .endswith (".kernel" ):
552551 if tensor .ndim == 4 :
@@ -555,12 +554,6 @@ def load_audio_vae_weights(
555554 if should_transpose :
556555 tensor = tensor .transpose (2 , 3 , 1 , 0 )
557556
558- # Handle special keys: latents_mean, latents_std
559- if "latents_mean" in key :
560- pass
561- if "latents_std" in key :
562- pass
563-
564557 # Convert key to tuple
565558 parts = key .split ("." )
566559 flax_key_parts = []
@@ -571,70 +564,6 @@ def load_audio_vae_weights(
571564 flax_key_parts .append (part )
572565
573566 flax_key = tuple (flax_key_parts )
574-
575- if "mid_block" in pt_key :
576- if "mid_block.resnets.0" in pt_key :
577- flax_key_str = "." .join ([str (x ) for x in flax_key ])
578- flax_key_str = flax_key_str .replace ("mid_block.resnets.0" , "mid_block1" )
579- elif "mid_block.resnets.1" in pt_key :
580- flax_key_str = "." .join ([str (x ) for x in flax_key ])
581- flax_key_str = flax_key_str .replace ("mid_block.resnets.1" , "mid_block2" )
582- elif "mid_block.attentions.0" in pt_key :
583- flax_key_str = "." .join ([str (x ) for x in flax_key ])
584- flax_key_str = flax_key_str .replace ("mid_block.attentions.0" , "mid_attn" )
585- else :
586- flax_key_str = "." .join ([str (x ) for x in flax_key ])
587-
588- parts = flax_key_str .split ("." )
589- flax_key_parts = []
590- for part in parts :
591- if part .isdigit ():
592- flax_key_parts .append (int (part ))
593- else :
594- flax_key_parts .append (part )
595- flax_key = tuple (flax_key_parts )
596-
597- if "down_blocks" in key :
598- key_str = "." .join ([str (x ) for x in flax_key ])
599- if "resnets" in key_str :
600- key_str = key_str .replace ("down_blocks" , "down_stages" )
601- key_str = key_str .replace ("resnets" , "blocks" )
602- elif "attentions" in key_str :
603- key_str = key_str .replace ("down_blocks" , "down_stages" )
604- key_str = key_str .replace ("attentions" , "attns" )
605- elif "downsamplers" in key_str :
606- key_str = key_str .replace ("down_blocks" , "down_stages" )
607- key_str = key_str .replace ("downsamplers.0" , "downsample" )
608-
609- parts = key_str .split ("." )
610- flax_key_parts = []
611- for part in parts :
612- if part .isdigit ():
613- flax_key_parts .append (int (part ))
614- else :
615- flax_key_parts .append (part )
616- flax_key = tuple (flax_key_parts )
617-
618- if "up_blocks" in key :
619- key_str = "." .join ([str (x ) for x in flax_key ])
620- if "resnets" in key_str :
621- key_str = key_str .replace ("up_blocks" , "up_stages" )
622- key_str = key_str .replace ("resnets" , "blocks" )
623- elif "attentions" in key_str :
624- key_str = key_str .replace ("up_blocks" , "up_stages" )
625- key_str = key_str .replace ("attentions" , "attns" )
626- elif "upsamplers" in key_str :
627- key_str = key_str .replace ("up_blocks" , "up_stages" )
628- key_str = key_str .replace ("upsamplers.0" , "upsample" )
629-
630- parts = key_str .split ("." )
631- flax_key_parts = []
632- for part in parts :
633- if part .isdigit ():
634- flax_key_parts .append (int (part ))
635- else :
636- flax_key_parts .append (part )
637- flax_key = tuple (flax_key_parts )
638567
639568 flax_state_dict [flax_key ] = jax .device_put (tensor , device = cpu )
640569
0 commit comments