@@ -256,16 +256,29 @@ def load_base_wan_transformer(
     for pt_key, tensor in tensors.items():
         renamed_pt_key = rename_key(pt_key)
         if "image_embedder" in renamed_pt_key:
-            if "net.0" in renamed_pt_key:
-                renamed_pt_key = renamed_pt_key.replace("net.0", "net_0.proj")
-                renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
+            # 1. Handle Layer 0 (HAS .proj in PyTorch -> Remove it for JAX)
+            # Source: condition_embedder.image_embedder.ff.net.0.proj.weight
+            # Target: condition_embedder.image_embedder.ff.net_0.kernel
+            if "net.0.proj" in renamed_pt_key:
+                renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0")
+
+            # 2. Handle Layer 2 (NO .proj in PyTorch -> Standard rename)
+            # Source: condition_embedder.image_embedder.ff.net.2.weight
+            # Target: condition_embedder.image_embedder.ff.net_2.kernel
             elif "net.2" in renamed_pt_key:
-                renamed_pt_key = renamed_pt_key.replace("net.2", "net_2.proj")
-                renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
+                renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
+
+            # 3. Fix Norm1 (Add .layer_norm wrapper)
             renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
+
+            # 4. Fix Norm Parameter Names (weight/kernel -> scale)
             if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
                 renamed_pt_key = renamed_pt_key.replace("weight", "scale")
                 renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
+
+            # 5. Ensure Dense Weights use 'kernel' (Fixes weight->kernel mapping if missed)
+            if "net_0" in renamed_pt_key or "net_2" in renamed_pt_key:
+                renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
         renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
         renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
         renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")