Skip to content

Commit 62c4d92

Browse files
committed
wan_utils.py: fix image_embedder key renaming (map net.0/net.2 to net_0.proj/net_2.proj, force norm params to 'scale' and dense params to 'kernel')
1 parent d021c9b commit 62c4d92

1 file changed

Lines changed: 17 additions & 16 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -256,33 +256,34 @@ def load_base_wan_transformer(
256256
for pt_key, tensor in tensors.items():
257257
renamed_pt_key = rename_key(pt_key)
258258
if "image_embedder" in renamed_pt_key:
259-
# 1. Handle Layer 0: "net.0" -> "net_0" (PRESERVE .proj)
260-
# Source: ...ff.net.0.proj.weight -> Target: ...ff.net_0.proj.kernel
259+
# 1. Handle Layer 0: PyTorch has "net.0.proj" -> JAX wants "net_0.proj"
260+
# We just replace the separator "net.0" -> "net_0"
261261
if "net.0" in renamed_pt_key:
262262
renamed_pt_key = renamed_pt_key.replace("net.0", "net_0")
263-
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
264263

265-
# 2. Handle Layer 2: "net.2" -> "net_2"
266-
# Source: ...ff.net.2.weight -> Target: ...ff.net_2.kernel
264+
# 2. Handle Layer 2: PyTorch has "net.2" (NO proj) -> JAX likely wants "net_2.proj"
265+
# We force the addition of ".proj" to match the symmetric JAX structure
267266
elif "net.2" in renamed_pt_key:
268-
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
269-
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
267+
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2.proj")
270268

271269
# 3. Handle Norm1: "norm1" -> "norm1.layer_norm"
272270
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
273271

274-
# 4. Force Norms to use "scale"
272+
# 4. Fix Parameter Names:
273+
# Norms (norm1, norm2) -> force 'scale'
275274
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
276275
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
277276
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
278-
279-
# 5. Global Norm Fix (Fixes 'norm_added_q', 'norm_k', etc.)
280-
# Any key containing 'norm' that ends in 'weight'/'kernel' should likely be 'scale'
281-
if "norm" in renamed_pt_key and ("weight" in renamed_pt_key or "kernel" in renamed_pt_key):
282-
# Exclude 'norm2' if it's already handled, or specific dense layers that might be named norm (unlikely)
283-
if "norm_added" in renamed_pt_key or "norm_k" in renamed_pt_key or "norm_q" in renamed_pt_key:
284-
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
285-
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
277+
# Dense Layers (net_0, net_2) -> force 'kernel'
278+
elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key:
279+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
280+
281+
# 5. Fix for 'norm_added_q' which showed up in your missing keys list
282+
# The error said 'kernel' was missing, implying this specific norm might act like a dense layer
283+
# OR it's a standard norm mismatch. We ensure it maps correctly to 'scale' first.
284+
if "norm_added" in renamed_pt_key:
285+
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
286+
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
286287
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
287288
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
288289
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

0 commit comments

Comments (0)