Skip to content

Commit 64cf8a4

Browse files
committed
reproduce key mapping
1 parent 4f375aa commit 64cf8a4

2 files changed

Lines changed: 3590 additions & 0 deletions

File tree

reproduce_key_mapping.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
2+
import jax
3+
import torch
4+
import re
5+
from flax import nnx
6+
from flax.traverse_util import flatten_dict, unflatten_dict
7+
from maxdiffusion.models.ltx2.ltx2_utils import rename_for_ltx2_transformer, get_key_and_value
8+
from maxdiffusion.models.modeling_flax_pytorch_utils import rename_key, rename_key_and_reshape_tensor
9+
10+
# Mock random_flax_state_dict (expected Flax keys)
11+
random_flax_state_dict = {
12+
('audio_caption_projection', 'linear_1', 'kernel'): "PLACEHOLDER",
13+
('audio_caption_projection', 'linear_1', 'bias'): "PLACEHOLDER",
14+
('transformer_blocks', 'audio_to_video_attn', 'norm_k', 'scale'): "PLACEHOLDER",
15+
('transformer_blocks', 'scale_shift_table',): "PLACEHOLDER",
16+
('transformer_blocks', '0', 'scale_shift_table'): "PLACEHOLDER", # If scanned, expected to be mapped here?
17+
}
18+
19+
# Values for "random_flax_state_dict" are not used by rename logic EXCEPT for checks relative to it.
20+
# We need to make sure we populate it enough for rename_key_and_reshape_tensor to work if it checks existence.
21+
22+
# Checkpoint keys to test
23+
checkpoint_keys = [
24+
"audio_caption_projection.linear_1.weight",
25+
"audio_caption_projection.linear_1.bias",
26+
"transformer_blocks.0.audio_to_video_attn.norm_k.weight",
27+
"transformer_blocks.0.scale_shift_table", # Expected in checkpoint? Index JSON says "transformer_blocks.0.scale_shift_table"?
28+
# JSON has: "audio_scale_shift_table" (global), and maybe block ones?
29+
# Let's check a block key from JSON if possible, but we only have global ones in snippet.
30+
# We saw "transformer_blocks.0.scale_shift_table" in debug prints?
31+
# Actually debug prints showed: "transformer_blocks.0.scale_shift_table"
32+
]
33+
34+
print("--- START DEBUG ---")
35+
36+
for pt_key in checkpoint_keys:
37+
print(f"\nProcessing Checkpoint Key: {pt_key}")
38+
39+
# 1. rename_key
40+
renamed_pt_key = rename_key(pt_key)
41+
print(f"After rename_key: {renamed_pt_key}")
42+
43+
# 2. rename_for_ltx2_transformer
44+
renamed_pt_key = rename_for_ltx2_transformer(renamed_pt_key)
45+
print(f"After rename_for_ltx2: {renamed_pt_key}")
46+
47+
pt_tuple_key = tuple(renamed_pt_key.split("."))
48+
print(f"Tuple Key: {pt_tuple_key}")
49+
50+
# 3. get_key_and_value
51+
# We need dummy tensor
52+
dummy_tensor = torch.zeros((10, 10))
53+
flax_state_dict = {} # Mock
54+
55+
# Need to simulate scan_layers=True
56+
scan_layers = True
57+
num_layers = 48
58+
59+
flax_key, flax_tensor = get_key_and_value(
60+
pt_tuple_key, dummy_tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers
61+
)
62+
63+
print(f"Final Flax Key: {flax_key}")
64+
65+
# Check if match
66+
if flax_key in random_flax_state_dict:
67+
print(">> MATCH FOUUND in random_flax_state_dict")
68+
else:
69+
print(">> MISSING in random_flax_state_dict")
70+
# Try finding partial match
71+
possible = [k for k in random_flax_state_dict if k[-1] == flax_key[-1]]
72+
if possible:
73+
print(f" Did you mean: {possible}?")

0 commit comments

Comments
 (0)