@@ -256,41 +256,30 @@ def load_base_wan_transformer(
256256 for pt_key , tensor in tensors .items ():
257257 renamed_pt_key = rename_key (pt_key )
258258 if "image_embedder" in renamed_pt_key :
259- # 1. Fix Shape Mismatch: Transpose Dense Layers
260- # CHECK FOR BOTH VERSIONS ("net.0"/"net_0" and "net.2"/"net_2")
261- # This ensures the transpose happens regardless of what rename_key did.
262259 if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
263260 "net.2" in renamed_pt_key or "net_2" in renamed_pt_key :
264261 tensor = tensor .T
265-
266- # 2. FIX net_0: Strip .proj (Handle both raw and renamed keys)
267262 if "net.0.proj" in renamed_pt_key :
268263 renamed_pt_key = renamed_pt_key .replace ("net.0.proj" , "net_0" )
269264 renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
270265 elif "net_0.proj" in renamed_pt_key :
271266 renamed_pt_key = renamed_pt_key .replace ("net_0.proj" , "net_0" )
272267 renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
273-
274- # 3. FIX net_2: Ensure naming and kernel (Handle both raw and renamed keys)
275268 if "net.2" in renamed_pt_key :
276269 renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2" )
277270 renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
278271 elif "net_2" in renamed_pt_key :
279272 renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
280-
281- # 4. FIX Norms: Add .layer_norm and force 'scale'
282273 renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
283274 if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
284275 renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
285276 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
286277
287- # --- FIX START: Missing Key 'norm_added_q' Fix ---
288278 if "norm_added" in renamed_pt_key :
289- renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
290- renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
291-
292279 if "attn2_norm_added" in renamed_pt_key :
293280 renamed_pt_key = renamed_pt_key .replace ("attn2_norm_added" , "attn2.norm_added" )
281+ renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
282+ renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
294283 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
295284 renamed_pt_key = renamed_pt_key .replace (".scale_shift_table" , ".adaln_scale_shift_table" )
296285 renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )
0 commit comments