File tree Expand file tree Collapse file tree
src/maxdiffusion/models/ltx2 Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -565,6 +565,23 @@ def load_audio_vae_weights(
565565
566566 flax_key = tuple (flax_key_parts )
567567
568+ # Reverse up_stages indices if present
569+ if "up_stages" in flax_key :
570+ # Find index of 'up_stages'
571+ try :
572+ up_stages_idx = flax_key .index ("up_stages" )
573+ # The integer index follows "up_stages"
574+ if up_stages_idx + 1 < len (flax_key ):
575+ stage_idx = flax_key [up_stages_idx + 1 ]
576+ if isinstance (stage_idx , int ):
577+ # Assuming 3 stages (0, 1, 2)
578+ # Map 0 -> 2, 1 -> 1, 2 -> 0
579+ new_stage_idx = 2 - stage_idx
580+ flax_key_parts [up_stages_idx + 1 ] = new_stage_idx
581+ flax_key = tuple (flax_key_parts )
582+ except ValueError :
583+ pass
584+
568585 flax_state_dict [flax_key ] = jax .device_put (tensor , device = cpu )
569586
570587 # Filter eval shapes to remove rngs/dropout
You can’t perform that action at this time.
0 commit comments