Skip to content

Commit 28f2e9e

Browse files
committed
missing keys error
1 parent cd7ad6a commit 28f2e9e

1 file changed

Lines changed: 8 additions & 3 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,17 @@ def load_base_wan_transformer(
276276
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
277277

278278
if "norm_added_q" in renamed_pt_key:
279-
# 1. Restore the hierarchy: Convert flattened "attn2_norm_added_q" back to "attn2.norm_added_q"
280-
# We use a broad replace to catch it regardless of what prefix (blocks_X_) is before it.
279+
# 1. Structural Fix: Ensure 'attn2' is separated from the block index
280+
# 'blocks_0_attn2' -> 'blocks_0.attn2'
281+
if "_attn2" in renamed_pt_key:
282+
renamed_pt_key = renamed_pt_key.replace("_attn2", ".attn2")
283+
284+
# 2. Restore hierarchy for the norm itself
285+
# 'attn2_norm_added_q' -> 'attn2.norm_added_q'
281286
if "attn2_norm_added_q" in renamed_pt_key:
282287
renamed_pt_key = renamed_pt_key.replace("attn2_norm_added_q", "attn2.norm_added_q")
283288

284-
# 2. Force 'weight' -> 'scale' (Crucial: JAX explicitly asks for 'scale')
289+
# 3. Force 'weight' -> 'scale'
285290
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
286291
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
287292
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")

0 commit comments

Comments
 (0)