Skip to content

Commit 62634b0

Browse files
committed
fix
1 parent ddfa80c commit 62634b0

1 file changed

Lines changed: 30 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def rename_for_ltx2_transformer(key):
4040
# Handle autoencoder_kl_ltx2 specific renames if any, but this is for transformer usually.
4141

4242
# Handle audio_ff.net_0.proj -> audio_ff.net_0
43-
if "audio_ff" in key and "proj" in key:
43+
# Also handle ff.net_0.proj -> ff.net_0
44+
if ("audio_ff" in key or "ff" in key) and "proj" in key:
4445
key = key.replace(".proj", "")
4546

4647
# Handle to_out.0 -> to_out for LTX2Attention
@@ -68,6 +69,34 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
6869
pass # Already handled above or matches standard format
6970

7071
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
72+
73+
# RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
74+
flax_key_str = [str(k) for k in flax_key]
75+
76+
# Helper to replace last occurrence
77+
def replace_suffix(lst, old, new):
78+
if lst and lst[-1] == old:
79+
lst[-1] = new
80+
return lst
81+
82+
# LTX-2 uses to_q, to_k, to_v, to_out, NOT query, key, value, proj_attn
83+
if "transformer_blocks" in flax_key_str:
84+
if flax_key_str[-1] == "query":
85+
flax_key_str[-1] = "to_q"
86+
elif flax_key_str[-1] == "key":
87+
flax_key_str[-1] = "to_k"
88+
elif flax_key_str[-1] == "value":
89+
flax_key_str[-1] = "to_v"
90+
91+
# For to_out, rename_key_and_reshape_tensor might map to_out_0 -> proj_attn
92+
# OR if we mapped `to_out_0` -> `to_out` manually, it keeps `to_out` but changes `weight` -> `kernel`
93+
# We just want to ensure consistency.
94+
# If it became `proj_attn`, revert it.
95+
if len(flax_key_str) >= 2 and flax_key_str[-2] == "proj_attn":
96+
# proj_attn, kernel -> to_out, kernel
97+
flax_key_str[-2] = "to_out"
98+
99+
flax_key = tuple(flax_key_str)
71100
flax_key = _tuple_str_to_int(flax_key)
72101

73102
if scan_layers and block_index is not None:

0 commit comments

Comments
 (0)