Skip to content

Commit 2be0be8

Browse files
committed
debug
1 parent aee603d commit 2be0be8

1 file changed

Lines changed: 25 additions & 4 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -29,26 +29,47 @@ def rename_for_ltx2_transformer(key):
2929
"""
3030
key = key.replace("patchify_proj", "proj_in")
3131
key = key.replace("audio_patchify_proj", "audio_proj_in")
32+
33+
if "caption_projection" in key:
34+
key = key.replace("caption_projection", "audio_caption_projection")
35+
36+
# Handle audio_ff.net_0.proj -> audio_ff.net_0
37+
if "audio_ff" in key and "proj" in key:
38+
key = key.replace(".proj", "")
39+
40+
# This line was redundant, keeping it as a no-op or removing it is fine.
41+
# The instruction implies it should be `return key` at the end.
3242
key = key.replace("transformer_blocks", "transformer_blocks")
3343
return key
3444

3545

3646
def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers=48):
47+
block_index = None
48+
49+
# Handle transformer_blocks_N produced by rename_key
50+
if scan_layers and len(pt_tuple_key) > 0 and "transformer_blocks_" in pt_tuple_key[0]:
51+
import re
52+
m = re.match(r"transformer_blocks_(\d+)", pt_tuple_key[0])
53+
if m:
54+
block_index = int(m.group(1))
55+
# Map transformer_blocks_N -> transformer_blocks
56+
pt_tuple_key = ("transformer_blocks",) + pt_tuple_key[1:]
57+
3758
if scan_layers:
3859
if "transformer_blocks" in pt_tuple_key:
39-
new_key = ("transformer_blocks",) + pt_tuple_key[2:]
40-
block_index = int(pt_tuple_key[1])
41-
pt_tuple_key = new_key
60+
pass # Already handled above or matches standard format
4261

4362
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
4463
flax_key = _tuple_str_to_int(flax_key)
4564

46-
if scan_layers:
65+
if scan_layers and block_index is not None:
4766
if "transformer_blocks" in flax_key:
4867
if flax_key in flax_state_dict:
4968
new_tensor = flax_state_dict[flax_key]
5069
else:
70+
# Initialize with correct shape (layers, ...)
5171
new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape, dtype=flax_tensor.dtype)
72+
5273
new_tensor = new_tensor.at[block_index].set(flax_tensor)
5374
flax_tensor = new_tensor
5475

0 commit comments

Comments
 (0)