@@ -115,10 +115,6 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
115115 # Also check 'weight' because rename_key might not have converted it to kernel if it wasn't a known Linear
116116 flax_key_str = [str (k ) for k in flax_key ]
117117
118- # DEBUG: Check specific keys
119- if "norm_k" in flax_key_str or "audio_caption_projection" in flax_key_str :
120- print (f"DEBUG: get_key_and_value mapping: { pt_tuple_key } -> { flax_key_str } " )
121-
122118 if flax_key_str [- 1 ] in ["kernel" , "weight" ]:
123119 # Try replacing with scale and check if it exists in random_flax_state_dict
124120 temp_key_str = flax_key_str [:- 1 ] + ["scale" ]
@@ -298,47 +294,10 @@ def load_transformer_weights(
298294 string_tuple = tuple ([str (item ) for item in key ])
299295 random_flax_state_dict [string_tuple ] = flattened_dict [key ]
300296
301- # DEBUG: Print keys to understand mapping
302- print ("DEBUG: Top 20 keys from Checkpoint (tensors):" )
303- for k in list (tensors .keys ())[:20 ]:
304- print (k )
305-
306- print ("DEBUG: NON-BLOCK keys in Checkpoint:" )
307- for k in tensors .keys ():
308- if "transformer_blocks" not in k :
309- print (k )
310-
311- print ("\n DEBUG: Top 20 keys from Flax Model (eval_shapes):" )
312- for k in list (random_flax_state_dict .keys ())[:20 ]:
313- print (k )
314-
315- print ("\n DEBUG: Transformer Block 0 keys from Checkpoint:" )
316- found_block_0 = False
317- for k in tensors .keys ():
318- if "transformer_blocks.0." in k or "transformer_blocks_0." in k :
319- print (k )
320- found_block_0 = True
321-
322- if not found_block_0 :
323- # Try looking for any block
324- for k in tensors .keys ():
325- if "transformer_blocks" in k :
326- print (f"Sample block key: { k } " )
327- break
328-
329- print ("\n DEBUG: Global Norm/LN candidates in Checkpoint:" )
330- for k in tensors .keys ():
331- if "norm" in k .lower () or "ln" in k .lower ():
332- if "transformer_blocks" not in k :
333- print (k )
297+ for key in flattened_dict :
298+ string_tuple = tuple ([str (item ) for item in key ])
299+ random_flax_state_dict [string_tuple ] = flattened_dict [key ]
334300
335- print ("\n DEBUG: Transformer Block keys from Flax Model (eval_shapes):" )
336- for k in list (random_flax_state_dict .keys ()):
337- k_str = str (k )
338- if "transformer_blocks" in k_str and ("attn1" in k_str or "ff" in k_str ):
339- print (f"EVAL_SHAPE: { k } " )
340- pass
341-
342301 for pt_key , tensor in tensors .items ():
343302 renamed_pt_key = rename_key (pt_key )
344303 renamed_pt_key = rename_for_ltx2_transformer (renamed_pt_key )
@@ -392,15 +351,6 @@ def load_vae_weights(
392351 cpu = jax .local_devices (backend = "cpu" )[0 ]
393352 flattened_eval = flatten_dict (eval_shapes )
394353
395- # DEBUG: Print keys to understand mapping
396- print ("DEBUG: Top 20 keys from VAE Checkpoint (tensors):" )
397- for k in list (tensors .keys ())[:20 ]:
398- print (k )
399-
400- flax_state_dict = {}
401- cpu = jax .local_devices (backend = "cpu" )[0 ]
402- flattened_eval = flatten_dict (eval_shapes )
403-
404354 random_flax_state_dict = {}
405355 for key in flattened_eval :
406356 string_tuple = tuple ([str (item ) for item in key ])
0 commit comments