Skip to content

Commit 40ebfdc

Browse files
committed
wan_utils.py fixed
1 parent 2f0450e commit 40ebfdc

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,13 @@ def load_base_wan_transformer(
256256
for pt_key, tensor in tensors.items():
257257
renamed_pt_key = rename_key(pt_key)
258258
if "image_embedder" in renamed_pt_key:
259-
if "net.0" in renamed_pt_key:
260-
renamed_pt_key = renamed_pt_key.replace("net.0", "net_0.proj")
261-
elif "net.2" in renamed_pt_key:
262-
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2.proj")
263-
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
264-
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
265-
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
259+
if "net.0" in renamed_pt_key:
260+
renamed_pt_key = renamed_pt_key.replace("net.0", "net_0.proj")
261+
elif "net.2" in renamed_pt_key:
262+
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2.proj")
263+
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
264+
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
265+
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
266266
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
267267
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
268268
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

0 commit comments

Comments
 (0)