@@ -245,16 +245,15 @@ def load_base_wan_transformer(
245245 tensors [k ] = torch2jax (f .get_tensor (k ))
246246 flax_state_dict = {}
247247 cpu = jax .local_devices (backend = "cpu" )[0 ]
248- flattened_dict = flatten_dict (eval_shapes )
248+ flattened_eval_shapes = flatten_dict (eval_shapes )
249249 # turn all block numbers to strings just for matching weights.
250250 # Later they will be turned back to ints.
251251 random_flax_state_dict = {}
252- for key in flattened_dict :
252+ for key in flattened_eval_shapes :
253253 string_tuple = tuple ([str (item ) for item in key ])
254- random_flax_state_dict [string_tuple ] = flattened_dict [key ]
254+ random_flax_state_dict [string_tuple ] = flattened_eval_shapes [key ]
255255 # del flattened_dict
256256 norm_added_q_buffer = {}
257- print (f"DEBUG: Total keys found in checkpoint: { len (tensors )} " )
258257 for pt_key , tensor in tensors .items ():
259258 if "norm_added_q" in pt_key and "weight" in pt_key :
260259 parts = pt_key .split ("." )
@@ -264,8 +263,8 @@ def load_base_wan_transformer(
264263 block_idx = int (parts [block_idx_loc ])
265264 tensor = tensor .T
266265 norm_added_q_buffer [block_idx ] = tensor
267- except Exception as e :
268- print ( f"Warning: skipped { pt_key } due to { e } " )
266+ except Exception :
267+ pass
269268 continue
270269
271270 renamed_pt_key = rename_key (pt_key )
@@ -302,33 +301,17 @@ def load_base_wan_transformer(
302301 sorted_keys = sorted (norm_added_q_buffer .keys ())
303302 sorted_tensors = [norm_added_q_buffer [i ] for i in sorted_keys ]
304303 stacked_tensor = jnp .stack (sorted_tensors , axis = 0 )
305-
306- target_key = None
307- print ("DEBUG: Searching eval_shapes for norm_added_q..." )
308- possible_keys = []
309-
310- for key_tuple in flattened_dict .keys ():
311- if "norm_added_q" in key_tuple :
312- possible_keys .append (key_tuple )
313-
314- if len (possible_keys ) > 0 :
315- # Pick the first one (should only be one for this specific layer)
316- target_key = possible_keys [0 ]
317- print (f"DEBUG: Found matching key in eval_shapes: { target_key } " )
318- flax_state_dict [target_key ] = jax .device_put (stacked_tensor , device = cpu )
319- else :
320- # If we still find nothing, print ALL keys to debug for the user
321- print ("CRITICAL ERROR: 'norm_added_q' NOT FOUND in eval_shapes." )
322- print ("DEBUG: Dumping sample keys from eval_shapes to help debug:" )
323- for i , k in enumerate (list (flattened_dict .keys ())[:20 ]):
324- print (f" { k } " )
325-
326- # Last resort fallback
327- manual_key = ('blocks' , 'attn2' , 'norm_added_q' , 'kernel' )
328- print (f"DEBUG: Attempting manual injection to { manual_key } " )
329- flax_state_dict [manual_key ] = jax .device_put (stacked_tensor , device = cpu )
330-
331- del flattened_dict
304+ final_key = ('blocks' , 'attn2' , 'norm_added_q' , 'kernel' )
305+ flax_state_dict [final_key ] = jax .device_put (stacked_tensor , device = cpu )
306+ print (f"DEBUG: Manually injected { final_key } into flax_state_dict" )
307+ if final_key not in flattened_eval_shapes :
308+ print (f"DEBUG: Key { final_key } missing in eval_shapes. Patching it now." )
309+ shape_struct = jax .ShapeDtypeStruct (
310+ shape = stacked_tensor .shape ,
311+ dtype = stacked_tensor .dtype
312+ )
313+ flattened_eval_shapes [final_key ] = shape_struct
314+ eval_shapes = unflatten_dict (flattened_eval_shapes )
332315
333316 validate_flax_state_dict (eval_shapes , flax_state_dict )
334317 flax_state_dict = unflatten_dict (flax_state_dict )
0 commit comments