Skip to content

Commit 3f25f69

Browse files
committed
fix
1 parent 3579b56 commit 3f25f69

1 file changed

Lines changed: 11 additions & 4 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ def rename_for_ltx2_transformer(key):
3333

3434
# Handle scale_shift_table
3535
# PyTorch: adaLN_modulation.1.weight/bias -> scale_shift_table
36-
if "adaLN_modulation.1" in key:
37-
key = key.replace("adaLN_modulation.1", "scale_shift_table")
36+
# rename_key changes adaLN_modulation.1 -> adaLN_modulation_1
37+
if "adaLN_modulation_1" in key:
38+
key = key.replace("adaLN_modulation_1", "scale_shift_table")
3839

3940
# Handle autoencoder_kl_ltx2 specific renames if any, but this is for transformer usually.
4041

@@ -43,8 +44,9 @@ def rename_for_ltx2_transformer(key):
4344
key = key.replace(".proj", "")
4445

4546
# Handle to_out.0 -> to_out for LTX2Attention
46-
if "to_out.0" in key:
47-
key = key.replace("to_out.0", "to_out")
47+
# rename_key changes to_out.0 -> to_out_0
48+
if "to_out_0" in key:
49+
key = key.replace("to_out_0", "to_out")
4850

4951
return key
5052

@@ -150,6 +152,11 @@ def load_transformer_weights(
150152
print("DEBUG: Top 20 keys from Checkpoint (tensors):")
151153
for k in list(tensors.keys())[:20]:
152154
print(k)
155+
156+
print("DEBUG: NON-BLOCK keys in Checkpoint:")
157+
for k in tensors.keys():
158+
if "transformer_blocks" not in k:
159+
print(k)
153160

154161

155162
print("\nDEBUG: Top 20 keys from Flax Model (eval_shapes):")

0 commit comments

Comments
 (0)