@@ -487,37 +487,25 @@ def load_connector_weights(
487487 return unflatten_dict (flax_state_dict )
488488
489489def rename_for_ltx2_audio_vae (key ):
490- # LTX-2 Audio VAE specific renaming
491-
492- # 1. Common renames
493490 if key .endswith (".weight" ):
494491 key = key .replace (".weight" , ".kernel" )
495-
496- # 2. Structure renames
497- # mid.block_1 -> mid_block1
492+
498493 key = key .replace ("mid.block_1" , "mid_block1" )
499494 key = key .replace ("mid.block_2" , "mid_block2" )
500- key = key .replace ("mid.attn_1" , "mid_attn" ) # Assumption, but safe
495+ key = key .replace ("mid.attn_1" , "mid_attn" )
501496
502- # up.0 -> up_stages.0
503497 key = key .replace ("up." , "up_stages." )
504- # down.0 -> down_stages.0
505498 key = key .replace ("down." , "down_stages." )
506499
507- # block.0 -> blocks.0
508500 key = key .replace ("block." , "blocks." )
509501
510- # nin_shortcut -> conv_shortcut_layer
511502 key = key .replace ("nin_shortcut" , "conv_shortcut_layer" )
512503
513- # In case upsample/downsample keys are just 'upsample' / 'downsample'
514- # Check for CausalConv wrapping in upsample
515- # If PT is 'upsample.conv.kernel' but Flax needs 'upsample.conv.conv.kernel'
516504 if "upsample.conv.kernel" in key :
517505 key = key .replace ("upsample.conv.kernel" , "upsample.conv.conv.kernel" )
518506 if "upsample.conv.bias" in key :
519507 key = key .replace ("upsample.conv.bias" , "upsample.conv.conv.bias" )
520-
508+
521509 return key
522510
523511
@@ -541,20 +529,13 @@ def load_audio_vae_weights(
541529 for pt_key , tensor in tensors .items ():
542530 key = rename_for_ltx2_audio_vae (pt_key )
543531
544- # Determine if we need to transpose (Conv weights: OHWI -> HWIO)
545- # Note: 1x1 convs might also be 4D.
546- # Standard Flax Conv: (H, W, I, O)
547- # Standard PyTorch Conv: (O, I, H, W)
548-
549532 should_transpose = False
550533 if key .endswith (".kernel" ):
551534 if tensor .ndim == 4 :
552535 should_transpose = True
553536
554537 if should_transpose :
555538 tensor = tensor .transpose (2 , 3 , 1 , 0 )
556-
557- # Convert key to tuple
558539 parts = key .split ("." )
559540 flax_key_parts = []
560541 for part in parts :
@@ -565,29 +546,19 @@ def load_audio_vae_weights(
565546
566547 flax_key = tuple (flax_key_parts )
567548
568- # Reverse up_stages indices if present
569549 if "up_stages" in flax_key :
570- # Find index of 'up_stages'
571550 try :
572551 up_stages_idx = flax_key .index ("up_stages" )
573- # The integer index follows "up_stages"
574552 if up_stages_idx + 1 < len (flax_key ):
575553 stage_idx = flax_key [up_stages_idx + 1 ]
576554 if isinstance (stage_idx , int ):
577- # Assuming 3 stages (0, 1, 2)
578- # Map 0 -> 2, 1 -> 1, 2 -> 0
579555 new_stage_idx = 2 - stage_idx
580- if "upsample" in flax_key :
581- # print(f"DEBUG REVERSAL: {flax_key} -> stage_idx={stage_idx} -> new={new_stage_idx}")
582- pass
583556 flax_key_parts [up_stages_idx + 1 ] = new_stage_idx
584557 flax_key = tuple (flax_key_parts )
585558 except ValueError :
586559 pass
587560
588561 flax_state_dict [flax_key ] = jax .device_put (tensor , device = cpu )
589-
590- # Filter eval shapes to remove rngs/dropout
591562 filtered_eval_shapes = {}
592563 for k , v in flattened_eval .items ():
593564 k_str = [str (x ) for x in k ]
0 commit comments