Skip to content

Commit cd7ad6a

Browse files
committed
missing keys error
1 parent 3d0b6c5 commit cd7ad6a

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,10 +276,12 @@ def load_base_wan_transformer(
276276
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
277277

278278
if "norm_added_q" in renamed_pt_key:
279-
# 1. Restore the dot before norm_added_q so JAX sees it as a submodule
280-
renamed_pt_key = renamed_pt_key.replace("attn2_norm_added_q", "attn2.norm_added_q")
279+
# 1. Restore the hierarchy: Convert flattened "attn2_norm_added_q" back to "attn2.norm_added_q"
280+
# We use a broad replace to catch it regardless of what prefix (blocks_X_) is before it.
281+
if "attn2_norm_added_q" in renamed_pt_key:
282+
renamed_pt_key = renamed_pt_key.replace("attn2_norm_added_q", "attn2.norm_added_q")
281283

282-
# 2. Force 'weight' -> 'scale' (JAX requirement)
284+
# 2. Force 'weight' -> 'scale' (Crucial: JAX explicitly asks for 'scale')
283285
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
284286
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
285287
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")

0 commit comments

Comments
 (0)