@@ -262,8 +262,6 @@ def load_base_wan_transformer(
262262 tensor = tensor .T
263263 norm_added_q_buffer [block_idx ] = tensor
264264 continue
265- if "norm_added_q" in pt_key :
266- debug_original = renamed_pt_key
267265 if "image_embedder" in renamed_pt_key :
268266 if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
269267 "net.2" in renamed_pt_key or "net_2" in renamed_pt_key :
@@ -288,17 +286,8 @@ def load_base_wan_transformer(
288286 renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )
289287 renamed_pt_key = renamed_pt_key .replace ("ffn.net_2" , "ffn.proj_out" )
290288 renamed_pt_key = renamed_pt_key .replace ("ffn.net_0" , "ffn.act_fn" )
291- renamed_pt_key = renamed_pt_key .replace ("norm2" , "norm2.layer_norm" )
292- if "norm_added_q" in pt_key :
293- print (f"DEBUG REPORT for { pt_key } :" )
294- print (f" 1. After rename_key : { debug_original } " )
295- print (f" 2. Final Key String : { renamed_pt_key } " )
296-
297- # Test parsing
298- pt_tuple_key = tuple (renamed_pt_key .split ("." ))
299- flax_key , _ = get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers )
300- print (f" 3. Parsed Flax Key : { flax_key } " )
301- print ("-" * 20 )
289+ if "norm2.layer_norm" not in renamed_pt_key :
290+ renamed_pt_key = renamed_pt_key .replace ("norm2" , "norm2.layer_norm" )
302291 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
303292 flax_key , flax_tensor = get_key_and_value (pt_tuple_key , tensor , flax_state_dict , random_flax_state_dict , scan_layers )
304293 flax_state_dict [flax_key ] = jax .device_put (jnp .asarray (flax_tensor ), device = cpu )
0 commit comments