Skip to content

Commit 3af836f

Browse files
committed
dim mismatch fix
1 parent a222767 commit 3af836f

1 file changed

Lines changed: 35 additions & 16 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -256,22 +256,41 @@ def load_base_wan_transformer(
256256
for pt_key, tensor in tensors.items():
257257
renamed_pt_key = rename_key(pt_key)
258258
if "image_embedder" in renamed_pt_key:
259-
if "net.0.proj" in renamed_pt_key:
260-
renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0")
261-
elif "net_0.proj" in renamed_pt_key:
262-
renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0")
263-
if "net.2" in renamed_pt_key:
264-
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
265-
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
266-
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
267-
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
268-
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
269-
elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key:
270-
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
271-
if "norm" in renamed_pt_key and "image_embedder" not in renamed_pt_key:
272-
if "norm_added" in renamed_pt_key or "norm_k" in renamed_pt_key or "norm_q" in renamed_pt_key:
273-
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
274-
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
259+
# 1. Fix Shape Mismatch: Transpose Dense Layers
260+
# Match both key spellings ("net.0"/"net_0" and "net.2"/"net_2").
261+
# This ensures the transpose happens regardless of what rename_key did.
262+
if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
263+
"net.2" in renamed_pt_key or "net_2" in renamed_pt_key:
264+
tensor = tensor.t() # <--- This will now correctly execute
265+
266+
# 2. FIX net_0: Strip .proj (Handle both raw and renamed keys)
267+
if "net.0.proj" in renamed_pt_key:
268+
renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0")
269+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
270+
elif "net_0.proj" in renamed_pt_key:
271+
renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0")
272+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
273+
274+
# 3. FIX net_2: Ensure naming and kernel (Handle both raw and renamed keys)
275+
if "net.2" in renamed_pt_key:
276+
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
277+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
278+
elif "net_2" in renamed_pt_key:
279+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
280+
281+
# 4. Fix norms: append .layer_norm to norm1, then rename norm1/norm2 'weight'/'kernel' to 'scale'
282+
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
283+
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
284+
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
285+
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
286+
287+
# 5. Fix missing key 'norm_added_q': keys containing 'norm_added' use 'scale' instead of 'weight'/'kernel'
288+
if "norm_added" in renamed_pt_key:
289+
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
290+
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
291+
292+
if "attn2_norm_added" in renamed_pt_key:
293+
renamed_pt_key = renamed_pt_key.replace("attn2_norm_added", "attn2.norm_added")
275294
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
276295
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
277296
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

0 commit comments

Comments
 (0)