Skip to content

Commit a222767

Browse files
committed
dim mismatch fix
1 parent dcf8afd commit a222767

1 file changed

Lines changed: 15 additions & 31 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 15 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -256,38 +256,22 @@ def load_base_wan_transformer(
256256
for pt_key, tensor in tensors.items():
257257
renamed_pt_key = rename_key(pt_key)
258258
if "image_embedder" in renamed_pt_key:
259-
260-
# 1. FIX net_0: Source has '.proj', Target (JAX) does NOT.
261-
# We check for BOTH "net.0.proj" (raw) and "net_0.proj" (renamed) to be safe.
262-
if "net.0.proj" in renamed_pt_key:
263-
renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0")
264-
elif "net_0.proj" in renamed_pt_key: # <--- THIS IS THE MISSING LINK
265-
renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0")
266-
267-
# 2. FIX net_2: Ensure consistent naming (net.2 -> net_2)
268-
# JAX wants 'net_2', source is 'net.2'.
269-
if "net.2" in renamed_pt_key:
270-
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
271-
272-
# 3. FIX Norm1: Add .layer_norm wrapper
273-
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
274-
275-
# 4. FIX Parameter Names (Weight -> Kernel/Scale)
276-
# Force 'scale' for all norms
277-
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
278-
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
279-
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
280-
281-
# Force 'kernel' for dense layers (net_0 and net_2)
282-
# We check 'net_0' because we just renamed it above.
283-
elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key:
284-
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
285-
286-
# 5. Global Norm Fix (e.g. norm_added_q)
259+
if "net.0.proj" in renamed_pt_key:
260+
renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0")
261+
elif "net_0.proj" in renamed_pt_key:
262+
renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0")
263+
if "net.2" in renamed_pt_key:
264+
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
265+
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
266+
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
267+
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
268+
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
269+
elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key:
270+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
287271
if "norm" in renamed_pt_key and "image_embedder" not in renamed_pt_key:
288-
if "norm_added" in renamed_pt_key or "norm_k" in renamed_pt_key or "norm_q" in renamed_pt_key:
289-
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
290-
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
272+
if "norm_added" in renamed_pt_key or "norm_k" in renamed_pt_key or "norm_q" in renamed_pt_key:
273+
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
274+
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
291275
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
292276
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
293277
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

0 commit comments

Comments (0)