@@ -36,6 +36,18 @@ def rename_for_ltx2_transformer(key):
3636 # rename_key changes adaLN_modulation.1 -> adaLN_modulation_1
3737 if "adaLN_modulation_1" in key :
3838 key = key .replace ("adaLN_modulation_1" , "scale_shift_table" )
39+
40+ # Handle video_a2v_cross_attn_scale_shift_table (caption_modulator?)
41+ # Checkpoint: caption_modulator.1.weight
42+ if "caption_modulator_1" in key :
43+ key = key .replace ("caption_modulator_1" , "video_a2v_cross_attn_scale_shift_table" )
44+
45+ # Audio caption modulator?
46+ # Checkpoint: audio_caption_modulator.1.weight (Guessing name)
47+ # Let's inspect checkpoint keys for clues if this guess fails.
48+ if "audio_caption_modulator_1" in key :
49+ key = key .replace ("audio_caption_modulator_1" , "audio_a2v_cross_attn_scale_shift_table" )
50+
3951
4052 # Handle autoencoder_kl_ltx2 specific renames if any, but this is for transformer usually.
4153
@@ -68,11 +80,24 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
6880 if "transformer_blocks" in pt_tuple_key :
6981 pass # Already handled above or matches standard format
7082
83+ # Handle scale_shift_table keys - they are Params, not Linear layers, so no 'kernel' suffix needed
84+ # We might have renamed them to scale_shift_table already in rename_for_ltx2_transformer
85+ if "scale_shift_table" in pt_tuple_key [- 1 ] or "scale_shift_table" in pt_tuple_key :
86+ # if we renamed it to ends with scale_shift_table, use it directly
87+ # But rename_key_and_reshape might have added kernel?
88+ pass
89+
7190 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict , scan_layers )
7291
7392 # RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
7493 flax_key_str = [str (k ) for k in flax_key ]
7594
95+ # Fix scale_shift_table mapping if it got 'kernel' appended
96+ if "scale_shift_table" in flax_key_str :
97+ # if last is kernel/weight, remove it
98+ if flax_key_str [- 1 ] in ["kernel" , "weight" ]:
99+ flax_key_str .pop ()
100+
76101 # Helper to replace last occurrence
77102 def replace_suffix (lst , old , new ):
78103 if lst and lst [- 1 ] == old :
0 commit comments