@@ -304,22 +304,29 @@ def load_base_wan_transformer(
304304 stacked_tensor = jnp .stack (sorted_tensors , axis = 0 )
305305
306306 target_key = None
307+ print ("DEBUG: Searching eval_shapes for norm_added_q..." )
308+ possible_keys = []
307309
308310 for key_tuple in flattened_dict .keys ():
309- # Check if this tuple looks like what we want
310- # We check if it ends with 'norm_added_q' and 'kernel'
311- if len (key_tuple ) >= 2 and key_tuple [- 2 :] == ('norm_added_q' , 'kernel' ):
312- target_key = key_tuple
313- break
314- if target_key :
315- print (f"DEBUG: Found authoritative key in eval_shapes: { target_key } " )
316- flax_state_dict [target_key ] = jax .device_put (stacked_tensor , device = cpu )
317- print (f"Successfully injected norm_added_q with shape { stacked_tensor .shape } " )
311+ if "norm_added_q" in key_tuple :
312+ possible_keys .append (key_tuple )
313+
314+ if len (possible_keys ) > 0 :
315+ # Pick the first one (should only be one for this specific layer)
316+ target_key = possible_keys [0 ]
317+ print (f"DEBUG: Found matching key in eval_shapes: { target_key } " )
318+ flax_state_dict [target_key ] = jax .device_put (stacked_tensor , device = cpu )
318319 else :
319- # Fallback (should typically not happen if error message was correct)
320- print ("CRITICAL WARNING: Could not find norm_added_q key in eval_shapes! Using manual fallback." )
321- manual_key = ('blocks' , 'attn2' , 'norm_added_q' , 'kernel' )
322- flax_state_dict [manual_key ] = jax .device_put (stacked_tensor , device = cpu )
320+ # If we still find nothing, print ALL keys to debug for the user
321+ print ("CRITICAL ERROR: 'norm_added_q' NOT FOUND in eval_shapes." )
322+ print ("DEBUG: Dumping sample keys from eval_shapes to help debug:" )
323+ for i , k in enumerate (list (flattened_dict .keys ())[:20 ]):
324+ print (f" { k } " )
325+
326+ # Last resort fallback
327+ manual_key = ('blocks' , 'attn2' , 'norm_added_q' , 'kernel' )
328+ print (f"DEBUG: Attempting manual injection to { manual_key } " )
329+ flax_state_dict [manual_key ] = jax .device_put (stacked_tensor , device = cpu )
323330
324331 del flattened_dict
325332
0 commit comments