We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a34529f commit aee603dCopy full SHA for aee603d
1 file changed
src/maxdiffusion/models/ltx2/ltx2_utils.py
@@ -118,6 +118,15 @@ def load_transformer_weights(
118
for key in flattened_dict:
119
string_tuple = tuple([str(item) for item in key])
120
random_flax_state_dict[string_tuple] = flattened_dict[key]
121
+
122
+ # DEBUG: Print keys to understand mapping
123
+ print("DEBUG: Top 20 keys from Checkpoint (tensors):")
124
+ for k in list(tensors.keys())[:20]:
125
+ print(k)
126
127
+ print("\nDEBUG: Top 20 keys from Flax Model (eval_shapes):")
128
+ for k in list(random_flax_state_dict.keys())[:20]:
129
130
131
for pt_key, tensor in tensors.items():
132
renamed_pt_key = rename_key(pt_key)
0 commit comments