@@ -256,37 +256,35 @@ def load_base_wan_transformer(
256256 for pt_key , tensor in tensors .items ():
257257 renamed_pt_key = rename_key (pt_key )
258258 if "image_embedder" in renamed_pt_key :
259- # 1. Handle Layer 0: PyTorch "net.0.proj" -> JAX "net_0"
260- # The JAX model likely defines net_0 as a Dense layer directly, not a block with a 'proj' sub-layer.
261- # We must strip ".proj" out.
262- if "net.0.proj" in renamed_pt_key :
263- renamed_pt_key = renamed_pt_key .replace ("net.0.proj" , "net_0" )
264- elif "net.0" in renamed_pt_key :
265- renamed_pt_key = renamed_pt_key .replace ("net.0" , "net_0" )
266-
267- # 2. Handle Layer 2: PyTorch "net.2" -> JAX "net_2"
268- # Do NOT add ".proj" here unless your JAX definition explicitly has it.
269- # The shape mismatch error suggests the model found 'net_2' but the shape was wrong.
270- if "net.2" in renamed_pt_key :
271- renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2" )
272-
273- # 3. Handle Norms: "norm1" -> "norm1.layer_norm" (Keeping your logic if JAX uses this structure)
274- # If standard Flax LayerNorm is used, usually just "norm1" is sufficient.
275- # Ensure "weight" becomes "scale".
276- if "norm" in renamed_pt_key :
277- renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
278- renamed_pt_key = renamed_pt_key .replace ("norm2" , "norm2.layer_norm" )
279- renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
280-
281- # 4. Handle Dense Layers (Kernels) and Fix Shape Mismatch
282- if "net_0" in renamed_pt_key or "net_2" in renamed_pt_key :
283- # Rename weight to kernel
284- if "weight" in renamed_pt_key :
285- renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
286-
287- # 5. Fix for 'norm_added_q' which showed up in your missing keys list
288- # The error said 'kernel' was missing, implying this specific norm might act like a dense layer
289- # OR it's a standard norm mismatch. We ensure it maps correctly to 'scale' first.
259+ # 1. Handle Layer 0: "net.0" -> "net_0"
260+ # Source: ...ff.net.0.proj.weight
261+ # We ONLY change net.0 to net_0. We KEEP .proj because JAX wants it.
262+ if "net.0" in renamed_pt_key :
263+ renamed_pt_key = renamed_pt_key .replace ("net.0" , "net_0" )
264+
265+ # 2. Handle Layer 2: "net.2" -> "net_2.proj"
266+ # Source: ...ff.net.2.weight (No proj in file)
267+ # We ADD .proj because JAX expects symmetry with Layer 0.
268+ elif "net.2" in renamed_pt_key :
269+ renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2.proj" )
270+
271+ # 3. Handle Norm1: "norm1" -> "norm1.layer_norm"
272+ renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
273+
274+ # 4. Fix Parameter Names (Critical Step)
275+ # Norms (norm1, norm2) -> force 'scale'
276+ if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
277+ renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
278+ renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
279+ # Dense Layers (net_0, net_2) -> force 'kernel'
280+ # This ensures that even if rename_key left it as 'weight', we force it to 'kernel'
281+ elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key :
282+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
283+
284+ # 5. Global Norm Fix (fixes 'norm_added_q', etc.)
285+ # Any key containing 'norm_added' has its parameter suffix rewritten to 'scale',
286+ # whether it arrives as 'weight' or was already renamed to 'kernel' upstream.
287+ # This is applied unconditionally, since JAX/Flax norm layers generally want 'scale'.
290288 if "norm_added" in renamed_pt_key :
291289 renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
292290 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
0 commit comments