Skip to content

Commit 3d0b6c5

Browse files
committed
missing keys error
1 parent d052901 commit 3d0b6c5

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,11 @@ def load_base_wan_transformer(
275275
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
276276
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
277277

278-
if "norm_added" in renamed_pt_key:
279-
if "attn2_norm_added" in renamed_pt_key:
280-
renamed_pt_key = renamed_pt_key.replace("attn2_norm_added", "attn2.norm_added")
278+
if "norm_added_q" in renamed_pt_key:
279+
# 1. Restore the dot before norm_added_q so JAX sees it as a submodule
280+
renamed_pt_key = renamed_pt_key.replace("attn2_norm_added_q", "attn2.norm_added_q")
281+
282+
# 2. Force 'weight' -> 'scale' (JAX requirement)
281283
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
282284
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
283285
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")

0 commit comments

Comments
 (0)