File tree Expand file tree Collapse file tree
src/maxdiffusion/models/wan Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -256,10 +256,15 @@ def load_base_wan_transformer(
256256 for pt_key , tensor in tensors .items ():
257257 renamed_pt_key = rename_key (pt_key )
258258 if "image_embedder" in renamed_pt_key :
259- renamed_pt_key = renamed_pt_key .replace ("net.0" , "net_0" )
260- renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2" )
259+ if "net.0" in renamed_pt_key :
260+ renamed_pt_key = renamed_pt_key .replace ("net.0" , "net_0.proj" )
261+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
262+ elif "net.2" in renamed_pt_key :
263+ renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2.proj" )
264+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
261265 renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
262266 if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
267+ renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
263268 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
264269 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
265270 renamed_pt_key = renamed_pt_key .replace (".scale_shift_table" , ".adaln_scale_shift_table" )
You can’t perform that action at this time.
0 commit comments