Skip to content

Commit 740063b

Browse files
committed
transformer weight loading
1 parent 4646119 commit 740063b

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,8 @@ def rename_for_ltx2_transformer(key):
107107
# Handle substrings before they are replaced by shorter patterns below
108108
key = key.replace("audio_prompt_adaln_single", "audio_prompt_adaln")
109109
key = key.replace("prompt_adaln_single", "prompt_adaln")
110+
key = key.replace("audio_prompt_scale_shift_table", "audio_scale_shift_table")
111+
key = key.replace("prompt_scale_shift_table", "scale_shift_table")
110112

111113
if "prompt_adaln" in key:
112114
key = key.replace("prompt_adaln", "caption_projection")
@@ -141,6 +143,12 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
141143
pt_tuple_key = ("transformer_blocks", str(block_index)) + pt_tuple_key[1:]
142144

143145
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
146+
147+
# For LTX-2.3, transpose 2-D caption-projection kernels back: these weights are already in JAX layout (or otherwise should not be transposed)
148+
if "caption_projection" in flax_key or "audio_caption_projection" in flax_key:
149+
if "kernel" in flax_key and flax_tensor.ndim == 2:
150+
flax_tensor = flax_tensor.T
151+
144152
flax_key_str = [str(k) for k in flax_key]
145153

146154
if "scale_shift_table" in flax_key_str:

0 commit comments

Comments
 (0)