@@ -255,6 +255,8 @@ def load_base_wan_transformer(
255255 del flattened_dict
256256 for pt_key , tensor in tensors .items ():
257257 renamed_pt_key = rename_key (pt_key )
258+ if "norm_added_q" in pt_key :
259+ debug_original = renamed_pt_key
258260 if "image_embedder" in renamed_pt_key :
259261 if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
260262 "net.2" in renamed_pt_key or "net_2" in renamed_pt_key :
@@ -276,16 +278,27 @@ def load_base_wan_transformer(
276278 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
277279
278280 if "norm_added_q" in renamed_pt_key :
279- renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
280- tensor = tensor .T
281+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
282+ tensor = tensor .T
281283 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
284+
282285
283286 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
284287 renamed_pt_key = renamed_pt_key .replace (".scale_shift_table" , ".adaln_scale_shift_table" )
285288 renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )
286289 renamed_pt_key = renamed_pt_key .replace ("ffn.net_2" , "ffn.proj_out" )
287290 renamed_pt_key = renamed_pt_key .replace ("ffn.net_0" , "ffn.act_fn" )
288291 renamed_pt_key = renamed_pt_key .replace ("norm2" , "norm2.layer_norm" )
292+ if "norm_added_q" in pt_key :
293+ print (f"DEBUG REPORT for { pt_key } :" )
294+ print (f" 1. After rename_key : { debug_original } " )
295+ print (f" 2. Final Key String : { renamed_pt_key } " )
296+
297+ # Test parsing
298+ pt_tuple_key = tuple (renamed_pt_key .split ("." ))
299+ flax_key , _ = get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers )
300+ print (f" 3. Parsed Flax Key : { flax_key } " )
301+ print ("-" * 20 )
289302 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
290303 flax_key , flax_tensor = get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers )
291304 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
0 commit comments