@@ -256,19 +256,25 @@ def load_base_wan_transformer(
256256 norm_added_q_buffer = {}
257257 print (f"DEBUG: Total keys found in checkpoint: { len (tensors )} " )
258258 for pt_key , tensor in tensors .items ():
259- if "norm_added_q" in pt_key :
260- print (f"DEBUG: Found norm_added_q key: { pt_key } " )
261- renamed_pt_key = rename_key (pt_key )
262- if "norm_added_q" in pt_key :
259+ if "norm_added_q" in pt_key and "weight" in pt_key :
263260 parts = pt_key .split ("." )
264261 try :
265- block_idx = int (parts [1 ])
262+ # Robustly find the block index (handles 'blocks.0' and 'model...blocks.0')
263+ if "blocks" in parts :
264+ block_idx_loc = parts .index ("blocks" ) + 1
265+ block_idx = int (parts [block_idx_loc ])
266+
267+ # Transpose and buffer
268+ tensor = tensor .T
269+ norm_added_q_buffer [block_idx ] = tensor
270+ else :
271+ print (f"Warning: skipped { pt_key } (no 'blocks' found)" )
266272 except ValueError :
267- print (f"DEBUG: Failed to parse index from { pt_key } " )
268- continue
269- tensor = tensor .T
270- norm_added_q_buffer [block_idx ] = tensor
273+ print (f"Warning: skipped { pt_key } (index parse error)" )
274+
275+ # Skip the standard processing for these keys
271276 continue
277+ renamed_pt_key = rename_key (pt_key )
272278 if "image_embedder" in renamed_pt_key :
273279 if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
274280 "net.2" in renamed_pt_key or "net_2" in renamed_pt_key :
@@ -299,10 +305,18 @@ def load_base_wan_transformer(
299305 flax_key , flax_tensor = get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers )
300306 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
301307 if norm_added_q_buffer :
302- sorted_tensors = [norm_added_q_buffer [i ] for i in sorted (norm_added_q_buffer .keys ())]
308+ # Sort by block index to ensure correct layer order
309+ sorted_keys = sorted (norm_added_q_buffer .keys ())
310+ sorted_tensors = [norm_added_q_buffer [i ] for i in sorted_keys ]
311+
312+ # Stack the per-block tensors along a new leading axis, in ascending block-index order
303313 stacked_tensor = jnp .stack (sorted_tensors , axis = 0 )
314+
315+ # KEY FIX: the error we hit expected a 'kernel' leaf for this parameter, so the flax key is hardcoded to end in 'kernel'.
304316 final_key = ('blocks' , 'attn2' , 'norm_added_q' , 'kernel' )
317+
305318 flax_state_dict [final_key ] = jax .device_put (stacked_tensor , device = cpu )
319+ print (f"Successfully injected { final_key } with shape { stacked_tensor .shape } " )
306320
307321 validate_flax_state_dict (eval_shapes , flax_state_dict )
308322 flax_state_dict = unflatten_dict (flax_state_dict )
0 commit comments