@@ -40,7 +40,8 @@ def rename_for_ltx2_transformer(key):
4040 # Handle autoencoder_kl_ltx2 specific renames if any, but this is for transformer usually.
4141
4242 # Handle audio_ff.net_0.proj -> audio_ff.net_0
43- if "audio_ff" in key and "proj" in key :
43+ # Also handle ff.net_0.proj -> ff.net_0
44+ if ("audio_ff" in key or "ff" in key ) and "proj" in key :
4445 key = key .replace (".proj" , "" )
4546
4647 # Handle to_out.0 -> to_out for LTX2Attention
@@ -68,6 +69,34 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
6869 pass # Already handled above or matches standard format
6970
7071 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict , scan_layers )
72+
73+ # RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
74+ flax_key_str = [str (k ) for k in flax_key ]
75+
76+ # Helper to replace last occurrence
77+ def replace_suffix (lst , old , new ):
78+ if lst and lst [- 1 ] == old :
79+ lst [- 1 ] = new
80+ return lst
81+
82+ # LTX-2 uses to_q, to_k, to_v, to_out, NOT query, key, value, proj_attn
83+ if "transformer_blocks" in flax_key_str :
84+ if flax_key_str [- 1 ] == "query" :
85+ flax_key_str [- 1 ] = "to_q"
86+ elif flax_key_str [- 1 ] == "key" :
87+ flax_key_str [- 1 ] = "to_k"
88+ elif flax_key_str [- 1 ] == "value" :
89+ flax_key_str [- 1 ] = "to_v"
90+
91+ # For to_out, rename_key_and_reshape_tensor might map to_out_0 -> proj_attn
92+ # OR if we mapped `to_out_0` -> `to_out` manually, it keeps `to_out` but changes `weight` -> `kernel`
93+ # We just want to ensure consistency.
94+ # If it became `proj_attn`, revert it.
95+ if len (flax_key_str ) >= 2 and flax_key_str [- 2 ] == "proj_attn" :
96+ # proj_attn, kernel -> to_out, kernel
97+ flax_key_str [- 2 ] = "to_out"
98+
99+ flax_key = tuple (flax_key_str )
71100 flax_key = _tuple_str_to_int (flax_key )
72101
73102 if scan_layers and block_index is not None :
0 commit comments