Skip to content

Commit 6b978af

Browse files
committed
missing keys error
1 parent d55c680 commit 6b978af

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,12 @@ def load_base_wan_transformer(
275275
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
276276
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
277277

278-
if "attn2.norm_added_q" in renamed_pt_key:
279-
if renamed_pt_key.endswith(".weight") or renamed_pt_key.endswith(".kernel"):
280-
renamed_pt_key = renamed_pt_key.rsplit(".", 1)[0] + ".scale"
278+
if ".attn2.norm_added_q." in renamed_pt_key:
279+
if renamed_pt_key.endswith(".weight"):
280+
renamed_pt_key = renamed_pt_key[:-len(".weight")] + ".scale"
281+
elif renamed_pt_key.endswith(".kernel"):
282+
renamed_pt_key = renamed_pt_key[:-len(".kernel")] + ".scale"
283+
281284
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
282285
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
283286
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

0 commit comments

Comments
 (0)