@@ -29,26 +29,47 @@ def rename_for_ltx2_transformer(key):
2929 """
3030 key = key .replace ("patchify_proj" , "proj_in" )
3131 key = key .replace ("audio_patchify_proj" , "audio_proj_in" )
32+
33+ if "caption_projection" in key :
34+ key = key .replace ("caption_projection" , "audio_caption_projection" )
35+
36+ # Handle audio_ff.net_0.proj -> audio_ff.net_0
37+ if "audio_ff" in key and "proj" in key :
38+ key = key .replace (".proj" , "" )
39+
40+ # This line was redundant, keeping it as a no-op or removing it is fine.
41+ # The instruction implies it should be `return key` at the end.
3242 key = key .replace ("transformer_blocks" , "transformer_blocks" )
3343 return key
3444
3545
3646def get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers , num_layers = 48 ):
47+ block_index = None
48+
49+ # Handle transformer_blocks_N produced by rename_key
50+ if scan_layers and len (pt_tuple_key ) > 0 and "transformer_blocks_" in pt_tuple_key [0 ]:
51+ import re
52+ m = re .match (r"transformer_blocks_(\d+)" , pt_tuple_key [0 ])
53+ if m :
54+ block_index = int (m .group (1 ))
55+ # Map transformer_blocks_N -> transformer_blocks
56+ pt_tuple_key = ("transformer_blocks" ,) + pt_tuple_key [1 :]
57+
3758 if scan_layers :
3859 if "transformer_blocks" in pt_tuple_key :
39- new_key = ("transformer_blocks" ,) + pt_tuple_key [2 :]
40- block_index = int (pt_tuple_key [1 ])
41- pt_tuple_key = new_key
60+ pass # Already handled above or matches standard format
4261
4362 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict , scan_layers )
4463 flax_key = _tuple_str_to_int (flax_key )
4564
46- if scan_layers :
65+ if scan_layers and block_index is not None :
4766 if "transformer_blocks" in flax_key :
4867 if flax_key in flax_state_dict :
4968 new_tensor = flax_state_dict [flax_key ]
5069 else :
70+ # Initialize with correct shape (layers, ...)
5171 new_tensor = jnp .zeros ((num_layers ,) + flax_tensor .shape , dtype = flax_tensor .dtype )
72+
5273 new_tensor = new_tensor .at [block_index ].set (flax_tensor )
5374 flax_tensor = new_tensor
5475
0 commit comments