@@ -256,29 +256,33 @@ def load_base_wan_transformer(
256256 for pt_key , tensor in tensors .items ():
257257 renamed_pt_key = rename_key (pt_key )
258258 if "image_embedder" in renamed_pt_key :
259- # 1. Handle Layer 0 (HAS .proj in PyTorch -> Remove it for JAX )
260- # Source: condition_embedder.image_embedder. ff.net.0.proj.weight
261- # Target: condition_embedder.image_embedder.ff.net_0.kernel
262- if "net.0.proj" in renamed_pt_key :
263- renamed_pt_key = renamed_pt_key .replace ("net.0.proj " , "net_0 " )
259+ # 1. Handle Layer 0: "net.0" -> "net_0" (PRESERVE .proj )
260+ # Source: ... ff.net.0.proj.weight -> Target: ...ff.net_0.proj.kernel
261+ if "net.0" in renamed_pt_key :
262+ renamed_pt_key = renamed_pt_key . replace ( "net.0" , "net_0" )
263+ renamed_pt_key = renamed_pt_key .replace ("weight " , "kernel " )
264264
265- # 2. Handle Layer 2 (NO .proj in PyTorch -> Standard rename)
266- # Source: condition_embedder.image_embedder.ff.net.2.weight
267- # Target: condition_embedder.image_embedder.ff.net_2.kernel
265+ # 2. Handle Layer 2: "net.2" -> "net_2"
266+ # Source: ...ff.net.2.weight -> Target: ...ff.net_2.kernel
268267 elif "net.2" in renamed_pt_key :
269268 renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2" )
269+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
270270
271- # 3. Fix Norm1 (Add .layer_norm wrapper)
271+ # 3. Handle Norm1: "norm1" -> "norm1 .layer_norm"
272272 renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
273273
274- # 4. Fix Norm Parameter Names (weight/kernel -> scale)
274+ # 4. Force Norms to use " scale"
275275 if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
276276 renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
277277 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
278-
279- # 5. Ensure Dense Weights use 'kernel' (Fixes weight->kernel mapping if missed)
280- if "net_0" in renamed_pt_key or "net_2" in renamed_pt_key :
281- renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
278+
279+ # 5. Global Norm Fix (Fixes 'norm_added_q', 'norm_k', etc.)
280+ # Any key containing 'norm' that ends in 'weight'/'kernel' should likely be 'scale'
281+ if "norm" in renamed_pt_key and ("weight" in renamed_pt_key or "kernel" in renamed_pt_key ):
282+ # Exclude 'norm2' if it's already handled, or specific dense layers that might be named norm (unlikely)
283+ if "norm_added" in renamed_pt_key or "norm_k" in renamed_pt_key or "norm_q" in renamed_pt_key :
284+ renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
285+ renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
282286 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
283287 renamed_pt_key = renamed_pt_key .replace (".scale_shift_table" , ".adaln_scale_shift_table" )
284288 renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )