Skip to content

Commit c692361

Browse files
committed
wan_utils.py fixed
1 parent 1e79c0b commit c692361

1 file changed

Lines changed: 29 additions & 31 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -256,37 +256,35 @@ def load_base_wan_transformer(
256256
for pt_key, tensor in tensors.items():
257257
renamed_pt_key = rename_key(pt_key)
258258
if "image_embedder" in renamed_pt_key:
259-
# 1. Handle Layer 0: PyTorch "net.0.proj" -> JAX "net_0"
260-
# The JAX model likely defines net_0 as a Dense layer directly, not a block with a 'proj' sub-layer.
261-
# We must strip ".proj" out.
262-
if "net.0.proj" in renamed_pt_key:
263-
renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0")
264-
elif "net.0" in renamed_pt_key:
265-
renamed_pt_key = renamed_pt_key.replace("net.0", "net_0")
266-
267-
# 2. Handle Layer 2: PyTorch "net.2" -> JAX "net_2"
268-
# Do NOT add ".proj" here unless your JAX definition explicitly has it.
269-
# The shape mismatch error suggests the model found 'net_2' but the shape was wrong.
270-
if "net.2" in renamed_pt_key:
271-
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
272-
273-
# 3. Handle Norms: "norm1" -> "norm1.layer_norm" (Keeping your logic if JAX uses this structure)
274-
# If standard Flax LayerNorm is used, usually just "norm1" is sufficient.
275-
# Ensure "weight" becomes "scale".
276-
if "norm" in renamed_pt_key:
277-
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
278-
renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
279-
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
280-
281-
# 4. Handle Dense Layers (Kernels) and Fix Shape Mismatch
282-
if "net_0" in renamed_pt_key or "net_2" in renamed_pt_key:
283-
# Rename weight to kernel
284-
if "weight" in renamed_pt_key:
285-
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
286-
287-
# 5. Fix for 'norm_added_q' which showed up in your missing keys list
288-
# The error said 'kernel' was missing, implying this specific norm might act like a dense layer
289-
# OR it's a standard norm mismatch. We ensure it maps correctly to 'scale' first.
259+
# 1. Handle Layer 0: "net.0" -> "net_0"
260+
# Source: ...ff.net.0.proj.weight
261+
# We ONLY change net.0 to net_0. We KEEP .proj because JAX wants it.
262+
if "net.0" in renamed_pt_key:
263+
renamed_pt_key = renamed_pt_key.replace("net.0", "net_0")
264+
265+
# 2. Handle Layer 2: "net.2" -> "net_2.proj"
266+
# Source: ...ff.net.2.weight (No proj in file)
267+
# We ADD .proj because JAX expects symmetry with Layer 0.
268+
elif "net.2" in renamed_pt_key:
269+
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2.proj")
270+
271+
# 3. Handle Norm1: "norm1" -> "norm1.layer_norm"
272+
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
273+
274+
# 4. Fix Parameter Names (Critical Step)
275+
# Norms (norm1, norm2) -> force 'scale'
276+
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
277+
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
278+
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
279+
# Dense Layers (net_0, net_2) -> force 'kernel'
280+
# This ensures that even if rename_key left it as 'weight', we force it to 'kernel'
281+
elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key:
282+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
283+
284+
# 5. Global Norm Fix (Fixes 'norm_added_q', etc.)
285+
# If JAX complained about missing 'kernel' for these, we respect that default.
286+
# If it complains about missing 'scale', we can uncomment the lines below.
287+
# Generally, JAX LayerNorms want 'scale'.
290288
if "norm_added" in renamed_pt_key:
291289
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
292290
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")

0 commit comments

Comments (0)