@@ -256,27 +256,37 @@ def load_base_wan_transformer(
     for pt_key, tensor in tensors.items():
         renamed_pt_key = rename_key(pt_key)
         if "image_embedder" in renamed_pt_key:
-            # 1. Handle Layer 0: PyTorch has "net.0.proj" -> JAX wants "net_0.proj"
-            # We just replace the separator "net.0" -> "net_0"
-            if "net.0" in renamed_pt_key:
-                renamed_pt_key = renamed_pt_key.replace("net.0", "net_0")
-
-            # 2. Handle Layer 2: PyTorch has "net.2" (NO proj) -> JAX likely wants "net_2.proj"
-            # We force the addition of ".proj" to match the symmetric JAX structure
-            elif "net.2" in renamed_pt_key:
-                renamed_pt_key = renamed_pt_key.replace("net.2", "net_2.proj")
-
-            # 3. Handle Norm1: "norm1" -> "norm1.layer_norm"
-            renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
-
-            # 4. Fix Parameter Names:
-            # Norms (norm1, norm2) -> force 'scale'
-            if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
-                renamed_pt_key = renamed_pt_key.replace("weight", "scale")
-                renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
-            # Dense Layers (net_0, net_2) -> force 'kernel'
-            elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key:
-                renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
+            # 1. Handle Layer 0: PyTorch "net.0.proj" -> JAX "net_0"
+            # The JAX model likely defines net_0 as a Dense layer directly,
+            # not a block with a 'proj' sub-layer, so strip ".proj" out.
+            if "net.0.proj" in renamed_pt_key:
+                renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0")
+            elif "net.0" in renamed_pt_key:
+                renamed_pt_key = renamed_pt_key.replace("net.0", "net_0")
+
+            # 2. Handle Layer 2: PyTorch "net.2" -> JAX "net_2"
+            # Do NOT add ".proj" here unless your JAX definition explicitly has it;
+            # the shape mismatch error suggests 'net_2' was found but had the wrong shape.
+            renamed_pt_key = renamed_pt_key.replace("net.2", "net_2")
+
+            # 3. Handle Norms: "norm1" -> "norm1.layer_norm" (kept in case the JAX
+            # model nests a LayerNorm; a plain Flax LayerNorm usually needs just "norm1").
+            # Flax LayerNorm names its learned parameter 'scale', not 'weight'.
+            if "norm" in renamed_pt_key:
+                renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm")
+                renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
+                renamed_pt_key = renamed_pt_key.replace("weight", "scale")
+
+            # 4. Handle Dense Layers (Kernels) and Fix Shape Mismatch
+            if "net_0" in renamed_pt_key or "net_2" in renamed_pt_key:
+                if "weight" in renamed_pt_key:
+                    # Flax Dense names its weight 'kernel'.
+                    renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
+                    # CRITICAL FIX: transpose the weights.
+                    # PyTorch Linear stores (out_features, in_features);
+                    # Flax Dense expects (in_features, out_features).
+                    # Guarding on "weight" leaves 1-D biases untouched.
+                    tensor = tensor.T
 
             # 5. Fix for 'norm_added_q', which showed up in the missing keys list
             # The error said 'kernel' was missing, implying this specific norm might act like a dense layer
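Pulled out of the loop, the whole mapping can be exercised in isolation. The sketch below is a minimal stand-in, assuming the Flax side really does use `net_0`/`net_2` Dense layers and a nested `layer_norm`; the helper name `rename_image_embedder_key`, the sample keys, and the shapes are hypothetical, not part of this commit.

```python
import numpy as np

def rename_image_embedder_key(pt_key: str, tensor: np.ndarray):
    """Map one PyTorch image_embedder checkpoint entry to its assumed Flax name."""
    key = pt_key
    # Chained replaces mirror the if/elif above: strip ".proj", then "." -> "_".
    key = key.replace("net.0.proj", "net_0").replace("net.0", "net_0")
    key = key.replace("net.2", "net_2")
    if "norm" in key:
        key = key.replace("norm1", "norm1.layer_norm")
        key = key.replace("norm2", "norm2.layer_norm")
        key = key.replace("weight", "scale")
    if ("net_0" in key or "net_2" in key) and "weight" in key:
        key = key.replace("weight", "kernel")
        tensor = tensor.T  # (out, in) -> (in, out); biases never reach this branch
    return key, tensor

# Hypothetical keys and shapes, for illustration only.
w = np.zeros((5120, 1280))  # PyTorch Linear weight: (out_features, in_features)
key, t = rename_image_embedder_key("image_embedder.net.0.proj.weight", w)
assert key == "image_embedder.net_0.kernel" and t.shape == (1280, 5120)
key, _ = rename_image_embedder_key("image_embedder.norm1.weight", np.zeros(1280))
assert key == "image_embedder.norm1.layer_norm.scale"
```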
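As a sanity check on the transpose convention itself: `torch.nn.Linear` computes `x @ W.T + b` with `W` stored as (out, in), while `flax.linen.Dense` computes `x @ kernel + b` with the kernel stored as (in, out), so storing `W.T` as the kernel reproduces the same output. A small numpy demonstration (shapes made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 1280))       # batch of inputs
W = rng.normal(size=(5120, 1280))    # PyTorch Linear weight: (out, in)
b = rng.normal(size=(5120,))

y_torch = x @ W.T + b                # what torch.nn.Linear computes
kernel = W.T                         # (in, out), the layout Flax Dense stores
y_flax = x @ kernel + b              # what Dense computes with that kernel

assert np.allclose(y_torch, y_flax)  # identical outputs after the transpose
```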