Skip to content

Commit d021c9b

Browse files
committed
wan_utils.py fixed
1 parent bbbbafe commit d021c9b

1 file changed

Lines changed: 18 additions & 14 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -256,29 +256,33 @@ def load_base_wan_transformer(
256256
for pt_key, tensor in tensors.items():
257257
renamed_pt_key = rename_key(pt_key)
258258
if "image_embedder" in renamed_pt_key:
259-
# 1. Handle Layer 0 (HAS .proj in PyTorch -> Remove it for JAX)
260-
# Source: condition_embedder.image_embedder.ff.net.0.proj.weight
261-
# Target: condition_embedder.image_embedder.ff.net_0.kernel
262-
if "net.0.proj" in renamed_pt_key:
263-
renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0")
259+
# 1. Handle Layer 0: "net.0" -> "net_0" (PRESERVE .proj)
260+
# Source: ...ff.net.0.proj.weight -> Target: ...ff.net_0.proj.kernel
261+
if "net.0" in renamed_pt_key:
262+
renamed_pt_key = renamed_pt_key.replace("net.0", "net_0")
263+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
264264

265-
# 2. Handle Layer 2 (NO .proj in PyTorch -> Standard rename)
266-
# Source: condition_embedder.image_embedder.ff.net.2.weight
267-
# Target: condition_embedder.image_embedder.ff.net_2.kernel
265+
# 2. Handle Layer 2: "net.2" -> "net_2"
266+
# Source: ...ff.net.2.weight -> Target: ...ff.net_2.kernel
268267
elif "net.2" in renamed_pt_key:
269268
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
269+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
270270

271-
# 3. Fix Norm1 (Add .layer_norm wrapper)
271+
# 3. Handle Norm1: "norm1" -> "norm1.layer_norm"
272272
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
273273

274-
# 4. Fix Norm Parameter Names (weight/kernel -> scale)
274+
# 4. Force Norms to use "scale"
275275
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
276276
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
277277
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
278-
279-
# 5. Ensure Dense Weights use 'kernel' (Fixes weight->kernel mapping if missed)
280-
if "net_0" in renamed_pt_key or "net_2" in renamed_pt_key:
281-
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
278+
279+
# 5. Global Norm Fix (Fixes 'norm_added_q', 'norm_k', etc.)
280+
# Keys that contain 'norm' together with 'weight'/'kernel' are candidates; only the norm_added*/norm_k/norm_q keys below are actually renamed to 'scale'
281+
if "norm" in renamed_pt_key and ("weight" in renamed_pt_key or "kernel" in renamed_pt_key):
282+
# Allowlist: rename only keys containing 'norm_added', 'norm_k' or 'norm_q'; any other key that merely contains 'norm' is left unchanged by this step
283+
if "norm_added" in renamed_pt_key or "norm_k" in renamed_pt_key or "norm_q" in renamed_pt_key:
284+
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
285+
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
282286
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
283287
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
284288
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

0 commit comments

Comments
 (0)