Skip to content

Commit 1878eab

Browse files
committed
fix
1 parent 3f25f69 commit 1878eab

1 file changed

Lines changed: 12 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,13 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
8080

8181
new_tensor = new_tensor.at[block_index].set(flax_tensor)
8282
flax_tensor = new_tensor
83-
83+
84+
# DEBUG TRACE
85+
if "audio_ff" in str(flax_key) and "kernel" in str(flax_key) and block_index == 18:
86+
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (Block 18)")
87+
if "to_out" in str(flax_key) and "kernel" in str(flax_key) and block_index == 18 and "attn1" in str(flax_key):
88+
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (Block 18 attn1)")
89+
8490
return flax_key, flax_tensor
8591

8692
def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device):
@@ -173,6 +179,11 @@ def load_transformer_weights(
173179
renamed_pt_key = rename_key(pt_key)
174180
renamed_pt_key = rename_for_ltx2_transformer(renamed_pt_key)
175181

182+
# DEBUG: Check intermediate rename
183+
if "audio_ff.net.0.proj" in pt_key:
184+
# This might spam, but good to see once
185+
pass
186+
176187
pt_tuple_key = tuple(renamed_pt_key.split("."))
177188

178189
flax_key, flax_tensor = get_key_and_value(

0 commit comments

Comments
 (0)