Skip to content

Commit dcf8afd

Browse files
committed
wan_utils.py fixed
1 parent 22b9ed6 commit dcf8afd

1 file changed

Lines changed: 23 additions & 15 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 23 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -257,29 +257,37 @@ def load_base_wan_transformer(
257257
renamed_pt_key = rename_key(pt_key)
258258
if "image_embedder" in renamed_pt_key:
259259

260-
# 1. FIX net_0 (Source has '.proj', Target does NOT)
261-
# Source: ...ff.net.0.proj.weight -> Target: ...ff.net_0.kernel
260+
# 1. FIX net_0: Source has '.proj', Target (JAX) does NOT.
261+
# We check for BOTH "net.0.proj" (raw) and "net_0.proj" (renamed) to be safe.
262262
if "net.0.proj" in renamed_pt_key:
263263
renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0")
264-
# FORCE 'weight' -> 'kernel' for this dense layer
265-
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
266-
267-
# 2. FIX net_2 (Source has NO '.proj', Target likely NO '.proj' for symmetry)
268-
# Source: ...ff.net.2.weight -> Target: ...ff.net_2.kernel
269-
elif "net.2" in renamed_pt_key:
264+
elif "net_0.proj" in renamed_pt_key: # <--- THIS IS THE MISSING LINK
265+
renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0")
266+
267+
# 2. FIX net_2: Ensure consistent naming (net.2 -> net_2)
268+
# JAX wants 'net_2', source is 'net.2'.
269+
if "net.2" in renamed_pt_key:
270270
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
271-
# FORCE 'weight' -> 'kernel' for this dense layer
272-
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
273-
274-
# 3. FIX Norm1 Nesting
275-
# Source: ...norm1.weight -> Target: ...norm1.layer_norm.scale
271+
272+
# 3. FIX Norm1: Add .layer_norm wrapper
276273
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
277274

278-
# 4. FIX Norm Parameter Names (Scale vs Weight)
275+
# 4. FIX Parameter Names (Weight -> Kernel/Scale)
276+
# Force 'scale' for all norms
279277
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
280278
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
281-
# Handle case where rename_key might have already turned it into kernel
282279
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
280+
281+
# Force 'kernel' for dense layers (net_0 and net_2)
282+
# We check 'net_0' because we just renamed it above.
283+
elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key:
284+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
285+
286+
# 5. Global Norm Fix (e.g. norm_added_q)
287+
if "norm" in renamed_pt_key and "image_embedder" not in renamed_pt_key:
288+
if "norm_added" in renamed_pt_key or "norm_k" in renamed_pt_key or "norm_q" in renamed_pt_key:
289+
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
290+
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
283291
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
284292
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
285293
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

0 commit comments

Comments
 (0)