@@ -33,8 +33,9 @@ def rename_for_ltx2_transformer(key):
3333
3434 # Handle scale_shift_table
3535 # PyTorch: adaLN_modulation.1.weight/bias -> scale_shift_table
36- if "adaLN_modulation.1" in key :
37- key = key .replace ("adaLN_modulation.1" , "scale_shift_table" )
36+ # rename_key changes adaLN_modulation.1 -> adaLN_modulation_1
37+ if "adaLN_modulation_1" in key :
38+ key = key .replace ("adaLN_modulation_1" , "scale_shift_table" )
3839
3940 # Handle autoencoder_kl_ltx2 specific renames if any, but this is for transformer usually.
4041
@@ -43,8 +44,9 @@ def rename_for_ltx2_transformer(key):
4344 key = key .replace (".proj" , "" )
4445
4546 # Handle to_out.0 -> to_out for LTX2Attention
46- if "to_out.0" in key :
47- key = key .replace ("to_out.0" , "to_out" )
47+ # rename_key changes to_out.0 -> to_out_0
48+ if "to_out_0" in key :
49+ key = key .replace ("to_out_0" , "to_out" )
4850
4951 return key
5052
@@ -150,6 +152,11 @@ def load_transformer_weights(
150152 print ("DEBUG: Top 20 keys from Checkpoint (tensors):" )
151153 for k in list (tensors .keys ())[:20 ]:
152154 print (k )
155+
156+ print ("DEBUG: NON-BLOCK keys in Checkpoint:" )
157+ for k in tensors .keys ():
158+ if "transformer_blocks" not in k :
159+ print (k )
153160
154161
155162 print ("\n DEBUG: Top 20 keys from Flax Model (eval_shapes):" )
0 commit comments