Skip to content

Commit 4534be4

Browse files
committed
fix
1 parent c95ce06 commit 4534be4

1 file changed

Lines changed: 22 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,11 +250,32 @@ def load_transformer_weights(
250250
for k in list(random_flax_state_dict.keys())[:20]:
251251
print(k)
252252

253+
print("\nDEBUG: Transformer Block 0 keys from Checkpoint:")
254+
found_block_0 = False
255+
for k in tensors.keys():
256+
if "transformer_blocks.0." in k or "transformer_blocks_0." in k:
257+
print(k)
258+
found_block_0 = True
259+
260+
if not found_block_0:
261+
# Try looking for any block
262+
for k in tensors.keys():
263+
if "transformer_blocks" in k:
264+
print(f"Sample block key: {k}")
265+
break
266+
267+
print("\nDEBUG: Global Norm/LN candidates in Checkpoint:")
268+
for k in tensors.keys():
269+
if "norm" in k.lower() or "ln" in k.lower():
270+
if "transformer_blocks" not in k:
271+
print(k)
272+
253273
print("\nDEBUG: Transformer Block keys from Flax Model (eval_shapes):")
254274
for k in list(random_flax_state_dict.keys()):
255275
k_str = str(k)
256276
if "transformer_blocks" in k_str and ("attn1" in k_str or "ff" in k_str):
257-
print(f"EVAL_SHAPE: {k}")
277+
# print(f"EVAL_SHAPE: {k}") # Comment out to reduce noise, we know they exist
278+
pass
258279

259280
for pt_key, tensor in tensors.items():
260281
renamed_pt_key = rename_key(pt_key)

0 commit comments

Comments
 (0)