@@ -256,22 +256,41 @@ def load_base_wan_transformer(
256256 for pt_key , tensor in tensors .items ():
257257 renamed_pt_key = rename_key (pt_key )
258258 if "image_embedder" in renamed_pt_key :
259- if "net.0.proj" in renamed_pt_key :
260- renamed_pt_key = renamed_pt_key .replace ("net.0.proj" , "net_0" )
261- elif "net_0.proj" in renamed_pt_key :
262- renamed_pt_key = renamed_pt_key .replace ("net_0.proj" , "net_0" )
263- if "net.2" in renamed_pt_key :
264- renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2" )
265- renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
266- if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
267- renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
268- renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
269- elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key :
270- renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
271- if "norm" in renamed_pt_key and "image_embedder" not in renamed_pt_key :
272- if "norm_added" in renamed_pt_key or "norm_k" in renamed_pt_key or "norm_q" in renamed_pt_key :
273- renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
274- renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
259+ # 1. Fix Shape Mismatch: Transpose Dense Layers
260+ # CHECK FOR BOTH VERSIONS ("net.0"/"net_0" and "net.2"/"net_2")
261+ # This ensures the transpose happens regardless of what rename_key did.
262+ if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
263+ "net.2" in renamed_pt_key or "net_2" in renamed_pt_key :
264+ tensor = tensor .t () # <--- This will now correctly execute
265+
266+ # 2. FIX net_0: Strip .proj (Handle both raw and renamed keys)
267+ if "net.0.proj" in renamed_pt_key :
268+ renamed_pt_key = renamed_pt_key .replace ("net.0.proj" , "net_0" )
269+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
270+ elif "net_0.proj" in renamed_pt_key :
271+ renamed_pt_key = renamed_pt_key .replace ("net_0.proj" , "net_0" )
272+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
273+
274+ # 3. FIX net_2: Ensure naming and kernel (Handle both raw and renamed keys)
275+ if "net.2" in renamed_pt_key :
276+ renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2" )
277+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
278+ elif "net_2" in renamed_pt_key :
279+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
280+
281+ # 4. FIX Norms: Add .layer_norm and force 'scale'
282+ renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
283+ if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
284+ renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
285+ renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
286+
287+ # --- FIX START: Missing Key 'norm_added_q' Fix ---
288+ if "norm_added" in renamed_pt_key :
289+ renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
290+ renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
291+
292+ if "attn2_norm_added" in renamed_pt_key :
293+ renamed_pt_key = renamed_pt_key .replace ("attn2_norm_added" , "attn2.norm_added" )
275294 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
276295 renamed_pt_key = renamed_pt_key .replace (".scale_shift_table" , ".adaln_scale_shift_table" )
277296 renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )
0 commit comments