@@ -256,13 +256,13 @@ def load_base_wan_transformer(
256256 for pt_key , tensor in tensors .items ():
257257 renamed_pt_key = rename_key (pt_key )
258258 if "image_embedder" in renamed_pt_key :
259- if "net.0" in renamed_pt_key :
260- renamed_pt_key = renamed_pt_key .replace ("net.0" , "net_0.proj" )
261- elif "net.2" in renamed_pt_key :
262- renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2.proj" )
263- renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
264- if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
265- renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
259+ if "net.0" in renamed_pt_key :
260+ renamed_pt_key = renamed_pt_key .replace ("net.0" , "net_0.proj" )
261+ elif "net.2" in renamed_pt_key :
262+ renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2.proj" )
263+ renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
264+ if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
265+ renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
266266 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
267267 renamed_pt_key = renamed_pt_key .replace (".scale_shift_table" , ".adaln_scale_shift_table" )
268268 renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )
0 commit comments