@@ -256,33 +256,34 @@ def load_base_wan_transformer(
256256 for pt_key , tensor in tensors .items ():
257257 renamed_pt_key = rename_key (pt_key )
258258 if "image_embedder" in renamed_pt_key :
259- # 1. Handle Layer 0: "net.0" -> "net_0" (PRESERVE .proj)
260- # Source: ...ff. net.0.proj.weight -> Target: ...ff. net_0.proj.kernel
259+ # 1. Handle Layer 0: PyTorch has "net.0.proj " -> JAX wants "net_0 .proj"
260+ # We just replace the separator " net.0" -> " net_0"
261261 if "net.0" in renamed_pt_key :
262262 renamed_pt_key = renamed_pt_key .replace ("net.0" , "net_0" )
263- renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
264263
265- # 2. Handle Layer 2: "net.2" -> "net_2"
266- # Source: ...ff.net.2.weight -> Target: ...ff.net_2.kernel
264+ # 2. Handle Layer 2: PyTorch has "net.2" (NO proj) -> JAX likely wants "net_2.proj "
265+ # We force the addition of ".proj" to match the symmetric JAX structure
267266 elif "net.2" in renamed_pt_key :
268- renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2" )
269- renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
267+ renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2.proj" )
270268
271269 # 3. Handle Norm1: "norm1" -> "norm1.layer_norm"
272270 renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
273271
274- # 4. Force Norms to use "scale"
272+ # 4. Fix Parameter Names:
273+ # Norms (norm1, norm2) -> force 'scale'
275274 if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
276275 renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
277276 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
278-
279- # 5. Global Norm Fix (Fixes 'norm_added_q', 'norm_k', etc.)
280- # Any key containing 'norm' that ends in 'weight'/'kernel' should likely be 'scale'
281- if "norm" in renamed_pt_key and ("weight" in renamed_pt_key or "kernel" in renamed_pt_key ):
282- # Exclude 'norm2' if it's already handled, or specific dense layers that might be named norm (unlikely)
283- if "norm_added" in renamed_pt_key or "norm_k" in renamed_pt_key or "norm_q" in renamed_pt_key :
284- renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
285- renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
277+ # Dense Layers (net_0, net_2) -> force 'kernel'
278+ elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key :
279+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
280+
281+ # 5. Fix for 'norm_added_q' which showed up in your missing keys list
282+ # The error said 'kernel' was missing, implying this specific norm might act like a dense layer
283+ # OR it's a standard norm mismatch. We ensure it maps correctly to 'scale' first.
284+ if "norm_added" in renamed_pt_key :
285+ renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
286+ renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
286287 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
287288 renamed_pt_key = renamed_pt_key .replace (".scale_shift_table" , ".adaln_scale_shift_table" )
288289 renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )
0 commit comments