@@ -256,38 +256,30 @@ def load_base_wan_transformer(
256256 for pt_key , tensor in tensors .items ():
257257 renamed_pt_key = rename_key (pt_key )
258258 if "image_embedder" in renamed_pt_key :
259- # 1. Handle Layer 0: "net.0" -> "net_0"
260- # Source: ...ff.net.0.proj.weight
261- # We ONLY change net.0 to net_0. We KEEP .proj because JAX wants it.
262- if "net.0" in renamed_pt_key :
263- renamed_pt_key = renamed_pt_key .replace ("net.0" , "net_0" )
264259
265- # 2. Handle Layer 2: "net.2" -> "net_2.proj"
266- # Source: ...ff.net.2.weight (No proj in file)
267- # We ADD .proj because JAX expects symmetry with Layer 0.
260+ # 1. FIX net_0 (Source has '.proj', Target does NOT)
261+ # Source: ...ff.net.0.proj.weight -> Target: ...ff.net_0.kernel
262+ if "net.0.proj" in renamed_pt_key :
263+ renamed_pt_key = renamed_pt_key .replace ("net.0.proj" , "net_0" )
264+ # FORCE 'weight' -> 'kernel' for this dense layer
265+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
266+
267+ # 2. FIX net_2 (Source has NO '.proj', Target likely NO '.proj' for symmetry)
268+ # Source: ...ff.net.2.weight -> Target: ...ff.net_2.kernel
268269 elif "net.2" in renamed_pt_key :
269- renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2.proj" )
270+ renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2" )
271+ # FORCE 'weight' -> 'kernel' for this dense layer
272+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
270273
271- # 3. Handle Norm1: "norm1" -> "norm1.layer_norm"
274+ # 3. FIX Norm1 Nesting
275+ # Source: ...norm1.weight -> Target: ...norm1.layer_norm.scale
272276 renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
273277
274- # 4. Fix Parameter Names (Critical Step)
275- # Norms (norm1, norm2) -> force 'scale'
278+ # 4. FIX Norm Parameter Names (Scale vs Weight)
276279 if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
277280 renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
281+ # Handle case where rename_key might have already turned it into kernel
278282 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
279- # Dense Layers (net_0, net_2) -> force 'kernel'
280- # This ensures that even if rename_key left it as 'weight', we force it to 'kernel'
281- elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key :
282- renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
283-
284- # 5. Global Norm Fix (Fixes 'norm_added_q', etc.)
285- # If JAX complained about missing 'kernel' for these, we respect that default.
286- # If it complains about missing 'scale', we can uncomment the lines below.
287- # Generally, JAX LayerNorms want 'scale'.
288- if "norm_added" in renamed_pt_key :
289- renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
290- renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
291283 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
292284 renamed_pt_key = renamed_pt_key .replace (".scale_shift_table" , ".adaln_scale_shift_table" )
293285 renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )
0 commit comments