Skip to content

Commit 22b9ed6

Browse files
committed
wan_utils.py fixed
1 parent c692361 commit 22b9ed6

1 file changed

Lines changed: 16 additions & 24 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 16 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -256,38 +256,30 @@ def load_base_wan_transformer(
256256
for pt_key, tensor in tensors.items():
257257
renamed_pt_key = rename_key(pt_key)
258258
if "image_embedder" in renamed_pt_key:
259-
# 1. Handle Layer 0: "net.0" -> "net_0"
260-
# Source: ...ff.net.0.proj.weight
261-
# We ONLY change net.0 to net_0. We KEEP .proj because JAX wants it.
262-
if "net.0" in renamed_pt_key:
263-
renamed_pt_key = renamed_pt_key.replace("net.0", "net_0")
264259

265-
# 2. Handle Layer 2: "net.2" -> "net_2.proj"
266-
# Source: ...ff.net.2.weight (No proj in file)
267-
# We ADD .proj because JAX expects symmetry with Layer 0.
260+
# 1. FIX net_0 (Source has '.proj', Target does NOT)
261+
# Source: ...ff.net.0.proj.weight -> Target: ...ff.net_0.kernel
262+
if "net.0.proj" in renamed_pt_key:
263+
renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0")
264+
# FORCE 'weight' -> 'kernel' for this dense layer
265+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
266+
267+
# 2. FIX net_2 (Source has NO '.proj', Target likely NO '.proj' for symmetry)
268+
# Source: ...ff.net.2.weight -> Target: ...ff.net_2.kernel
268269
elif "net.2" in renamed_pt_key:
269-
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2.proj")
270+
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
271+
# FORCE 'weight' -> 'kernel' for this dense layer
272+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
270273

271-
# 3. Handle Norm1: "norm1" -> "norm1.layer_norm"
274+
# 3. FIX Norm1 Nesting
275+
# Source: ...norm1.weight -> Target: ...norm1.layer_norm.scale
272276
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
273277

274-
# 4. Fix Parameter Names (Critical Step)
275-
# Norms (norm1, norm2) -> force 'scale'
278+
# 4. FIX Norm Parameter Names (Scale vs Weight)
276279
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
277280
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
281+
# Handle case where rename_key might have already turned it into kernel
278282
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
279-
# Dense Layers (net_0, net_2) -> force 'kernel'
280-
# This ensures that even if rename_key left it as 'weight', we force it to 'kernel'
281-
elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key:
282-
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
283-
284-
# 5. Global Norm Fix (Fixes 'norm_added_q', etc.)
285-
# If JAX complained about missing 'kernel' for these, we respect that default.
286-
# If it complains about missing 'scale', we can uncomment the lines below.
287-
# Generally, JAX LayerNorms want 'scale'.
288-
if "norm_added" in renamed_pt_key:
289-
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
290-
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
291283
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
292284
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
293285
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

0 commit comments

Comments
 (0)