@@ -256,38 +256,22 @@ def load_base_wan_transformer(
256256 for pt_key , tensor in tensors .items ():
257257 renamed_pt_key = rename_key (pt_key )
258258 if "image_embedder" in renamed_pt_key :
259-
260- # 1. FIX net_0: Source has '.proj', Target (JAX) does NOT.
261- # We check for BOTH "net.0.proj" (raw) and "net_0.proj" (renamed) to be safe.
262- if "net.0.proj" in renamed_pt_key :
263- renamed_pt_key = renamed_pt_key .replace ("net.0.proj" , "net_0" )
264- elif "net_0.proj" in renamed_pt_key : # <--- THIS IS THE MISSING LINK
265- renamed_pt_key = renamed_pt_key .replace ("net_0.proj" , "net_0" )
266-
267- # 2. FIX net_2: Ensure consistent naming (net.2 -> net_2)
268- # JAX wants 'net_2', source is 'net.2'.
269- if "net.2" in renamed_pt_key :
270- renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2" )
271-
272- # 3. FIX Norm1: Add .layer_norm wrapper
273- renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
274-
275- # 4. FIX Parameter Names (Weight -> Kernel/Scale)
276- # Force 'scale' for all norms
277- if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
278- renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
279- renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
280-
281- # Force 'kernel' for dense layers (net_0 and net_2)
282- # We check 'net_0' because we just renamed it above.
283- elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key :
284- renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
285-
286- # 5. Global Norm Fix (e.g. norm_added_q)
259+ if "net.0.proj" in renamed_pt_key :
260+ renamed_pt_key = renamed_pt_key .replace ("net.0.proj" , "net_0" )
261+ elif "net_0.proj" in renamed_pt_key :
262+ renamed_pt_key = renamed_pt_key .replace ("net_0.proj" , "net_0" )
263+ if "net.2" in renamed_pt_key :
264+ renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2" )
265+ renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
266+ if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
267+ renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
268+ renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
269+ elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key :
270+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
287271 if "norm" in renamed_pt_key and "image_embedder" not in renamed_pt_key :
288- if "norm_added" in renamed_pt_key or "norm_k" in renamed_pt_key or "norm_q" in renamed_pt_key :
289- renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
290- renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
272+ if "norm_added" in renamed_pt_key or "norm_k" in renamed_pt_key or "norm_q" in renamed_pt_key :
273+ renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
274+ renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
291275 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
292276 renamed_pt_key = renamed_pt_key .replace (".scale_shift_table" , ".adaln_scale_shift_table" )
293277 renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )
0 commit comments