Skip to content

Commit d052901

Browse files
committed
missing keys error
1 parent fc0c4db commit d052901

1 file changed

Lines changed: 2 additions & 13 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -256,41 +256,30 @@ def load_base_wan_transformer(
256256
for pt_key, tensor in tensors.items():
257257
renamed_pt_key = rename_key(pt_key)
258258
if "image_embedder" in renamed_pt_key:
259-
# 1. Fix Shape Mismatch: Transpose Dense Layers
260-
# CHECK FOR BOTH VERSIONS ("net.0"/"net_0" and "net.2"/"net_2")
261-
# This ensures the transpose happens regardless of what rename_key did.
262259
if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
263260
"net.2" in renamed_pt_key or "net_2" in renamed_pt_key:
264261
tensor = tensor.T
265-
266-
# 2. FIX net_0: Strip .proj (Handle both raw and renamed keys)
267262
if "net.0.proj" in renamed_pt_key:
268263
renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0")
269264
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
270265
elif "net_0.proj" in renamed_pt_key:
271266
renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0")
272267
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
273-
274-
# 3. FIX net_2: Ensure naming and kernel (Handle both raw and renamed keys)
275268
if "net.2" in renamed_pt_key:
276269
renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
277270
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
278271
elif "net_2" in renamed_pt_key:
279272
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
280-
281-
# 4. FIX Norms: Add .layer_norm and force 'scale'
282273
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
283274
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
284275
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
285276
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
286277

287-
# --- FIX START: Missing Key 'norm_added_q' Fix ---
288278
if "norm_added" in renamed_pt_key:
289-
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
290-
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
291-
292279
if "attn2_norm_added" in renamed_pt_key:
293280
renamed_pt_key = renamed_pt_key.replace("attn2_norm_added", "attn2.norm_added")
281+
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
282+
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
294283
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
295284
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
296285
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")

0 commit comments

Comments
 (0)