@@ -252,28 +252,22 @@ def load_base_wan_transformer(
     for key in flattened_dict:
         string_tuple = tuple([str(item) for item in key])
         random_flax_state_dict[string_tuple] = flattened_dict[key]
-    del flattened_dict
+    # del flattened_dict
     norm_added_q_buffer = {}
     print(f"DEBUG: Total keys found in checkpoint: {len(tensors)}")
     for pt_key, tensor in tensors.items():
         if "norm_added_q" in pt_key and "weight" in pt_key:
             parts = pt_key.split(".")
             try:
-                # Robustly find the block index (handles 'blocks.0' and 'model...blocks.0')
                 if "blocks" in parts:
                     block_idx_loc = parts.index("blocks") + 1
                     block_idx = int(parts[block_idx_loc])
-
-                    # Transpose and buffer
                     tensor = tensor.T
                     norm_added_q_buffer[block_idx] = tensor
-                else:
-                    print(f"Warning: skipped {pt_key} (no 'blocks' found)")
-            except ValueError:
-                print(f"Warning: skipped {pt_key} (index parse error)")
-
-            # Skip the standard processing for these keys
+            except Exception as e:
+                print(f"Warning: skipped {pt_key} due to {e}")
             continue
+
         renamed_pt_key = rename_key(pt_key)
         if "image_embedder" in renamed_pt_key:
             if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
@@ -305,18 +299,29 @@ def load_base_wan_transformer(
         flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
         flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
     if norm_added_q_buffer:
-        # Sort by block index to ensure correct layer order
         sorted_keys = sorted(norm_added_q_buffer.keys())
         sorted_tensors = [norm_added_q_buffer[i] for i in sorted_keys]
-
-        # Stack into shape (40, ...)
         stacked_tensor = jnp.stack(sorted_tensors, axis=0)
+
+        target_key = None
+
+        for key_tuple in flattened_dict.keys():
+            # Match the authoritative parameter key from eval_shapes:
+            # any tuple ending in ('norm_added_q', 'kernel').
+            if len(key_tuple) >= 2 and key_tuple[-2:] == ('norm_added_q', 'kernel'):
+                target_key = key_tuple
+                break
+        if target_key:
+            print(f"DEBUG: Found authoritative key in eval_shapes: {target_key}")
+            flax_state_dict[target_key] = jax.device_put(stacked_tensor, device=cpu)
+            print(f"Successfully injected norm_added_q with shape {stacked_tensor.shape}")
+        else:
+            # Fallback: should not happen if eval_shapes really contains this key.
+            print("CRITICAL WARNING: Could not find norm_added_q key in eval_shapes! Using manual fallback.")
+            manual_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
+            flax_state_dict[manual_key] = jax.device_put(stacked_tensor, device=cpu)
 
-        # KEY FIX: The error demanded 'kernel', so we hardcode that key here.
-        final_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
-
-        flax_state_dict[final_key] = jax.device_put(stacked_tensor, device=cpu)
-        print(f"Successfully injected {final_key} with shape {stacked_tensor.shape}")
+    del flattened_dict
 
     validate_flax_state_dict(eval_shapes, flax_state_dict)
     flax_state_dict = unflatten_dict(flax_state_dict)
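
For reference, here is a minimal, self-contained sketch of the strategy the new code implements: buffer each block's norm_added_q weight by its parsed block index, stack the buffered tensors along a new leading axis (the scan dimension), and resolve the destination key from the flattened eval_shapes tree rather than hardcoding it. The shapes and checkpoint keys below are toy stand-ins, not the repository's actual values.

import numpy as np
import jax
import jax.numpy as jnp

# Sketch only: toy stand-ins for the loader's real inputs.
NUM_BLOCKS, DIM = 4, 8  # the real model has many more blocks

# Stand-in for the PyTorch checkpoint: one norm_added_q weight per block.
tensors = {
    f"blocks.{i}.attn2.norm_added_q.weight": np.full((DIM,), float(i))
    for i in range(NUM_BLOCKS)
}

# Stand-in for the flattened eval_shapes tree (keys are string tuples).
flattened_dict = {("blocks", "attn2", "norm_added_q", "kernel"): None}

# 1. Buffer each block's tensor, keyed by the parsed block index.
buffer = {}
for pt_key, tensor in tensors.items():
    parts = pt_key.split(".")
    if "blocks" in parts:
        block_idx = int(parts[parts.index("blocks") + 1])
        buffer[block_idx] = tensor.T  # .T is a no-op for these 1-D toys

# 2. Stack in block order so axis 0 lines up with the scan dimension.
stacked = jnp.stack([buffer[i] for i in sorted(buffer)], axis=0)

# 3. Resolve the authoritative key from eval_shapes instead of hardcoding it.
target_key = next(
    (k for k in flattened_dict if k[-2:] == ("norm_added_q", "kernel")), None
)
assert target_key is not None, "eval_shapes is missing the norm_added_q kernel"
flax_state_dict = {target_key: jax.device_put(stacked)}
print(target_key, flax_state_dict[target_key].shape)  # (..., 'kernel') (4, 8)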