@@ -171,17 +171,24 @@ def load_transformer_weights(
171171
172172 print ("\n DEBUG: Transformer Block keys from Flax Model (eval_shapes):" )
173173 for k in list (random_flax_state_dict .keys ()):
174- if "transformer_blocks" in k and "attn1" in k :
175- print (k )
176- break
174+ k_str = str (k )
175+ if "transformer_blocks" in k_str and ("attn1" in k_str or "ff" in k_str ):
176+ print (f"EVAL_SHAPE: { k } " )
177+ if "proj_out" in k_str or "norm_out" in k_str :
178+ print (f"EVAL_SHAPE GLOBAL: { k } " )
179+
180+ # Search for norm in tensors
181+ print ("\n DEBUG: Search 'norm' in checkpoint keys:" )
182+ for k in tensors .keys ():
183+ if "norm" in k and "transformer_blocks" not in k :
184+ print (f"CKPT norm: { k } " )
177185
178186 for pt_key , tensor in tensors .items ():
179187 renamed_pt_key = rename_key (pt_key )
180188 renamed_pt_key = rename_for_ltx2_transformer (renamed_pt_key )
181189
182190 # DEBUG: Check intermediate rename
183191 if "audio_ff.net.0.proj" in pt_key :
184- # This might spam, but good to see once
185192 pass
186193
187194 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
@@ -190,6 +197,13 @@ def load_transformer_weights(
190197 pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers , num_layers
191198 )
192199
200+ # DEBUG: Trace proj_out
201+ if "proj_out" in str (flax_key ) and "bias" in str (flax_key ):
202+ print (f"DEBUG: Trace proj_out: { pt_key } -> { flax_key } " )
203+ # Check if added to dict
204+ # It acts global so it should be added below
205+
206+
193207 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
194208
195209 validate_flax_state_dict (eval_shapes , flax_state_dict )
0 commit comments