Skip to content

Commit aee603d

Browse files
committed
debug
1 parent a34529f commit aee603d

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,15 @@ def load_transformer_weights(
118118
for key in flattened_dict:
119119
string_tuple = tuple([str(item) for item in key])
120120
random_flax_state_dict[string_tuple] = flattened_dict[key]
121+
122+
# DEBUG: Print keys to understand mapping
123+
print("DEBUG: Top 20 keys from Checkpoint (tensors):")
124+
for k in list(tensors.keys())[:20]:
125+
print(k)
126+
127+
print("\nDEBUG: Top 20 keys from Flax Model (eval_shapes):")
128+
for k in list(random_flax_state_dict.keys())[:20]:
129+
print(k)
121130

122131
for pt_key, tensor in tensors.items():
123132
renamed_pt_key = rename_key(pt_key)

0 commit comments

Comments
 (0)