File tree Expand file tree Collapse file tree
src/maxdiffusion/models/ltx2 Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -250,11 +250,32 @@ def load_transformer_weights(
250250 for k in list (random_flax_state_dict .keys ())[:20 ]:
251251 print (k )
252252
253+ print ("\n DEBUG: Transformer Block 0 keys from Checkpoint:" )
254+ found_block_0 = False
255+ for k in tensors .keys ():
256+ if "transformer_blocks.0." in k or "transformer_blocks_0." in k :
257+ print (k )
258+ found_block_0 = True
259+
260+ if not found_block_0 :
261+ # Try looking for any block
262+ for k in tensors .keys ():
263+ if "transformer_blocks" in k :
264+ print (f"Sample block key: { k } " )
265+ break
266+
267+ print ("\n DEBUG: Global Norm/LN candidates in Checkpoint:" )
268+ for k in tensors .keys ():
269+ if "norm" in k .lower () or "ln" in k .lower ():
270+ if "transformer_blocks" not in k :
271+ print (k )
272+
253273 print ("\n DEBUG: Transformer Block keys from Flax Model (eval_shapes):" )
254274 for k in list (random_flax_state_dict .keys ()):
255275 k_str = str (k )
256276 if "transformer_blocks" in k_str and ("attn1" in k_str or "ff" in k_str ):
257- print (f"EVAL_SHAPE: { k } " )
277+ # print(f"EVAL_SHAPE: {k}") # Comment out to reduce noise, we know they exist
278+ pass
258279
259280 for pt_key , tensor in tensors .items ():
260281 renamed_pt_key = rename_key (pt_key )
You can’t perform that action at this time.
0 commit comments