@@ -257,29 +257,37 @@ def load_base_wan_transformer(
257257 renamed_pt_key = rename_key (pt_key )
258258 if "image_embedder" in renamed_pt_key :
259259
260- # 1. FIX net_0 ( Source has '.proj', Target does NOT)
261- # Source: ...ff. net.0.proj.weight -> Target: ...ff. net_0.kernel
260+ # 1. FIX net_0: Source has '.proj', Target (JAX) does NOT.
261+ # We check for BOTH " net.0.proj" (raw) and " net_0.proj" (renamed) to be safe.
262262 if "net.0.proj" in renamed_pt_key :
263263 renamed_pt_key = renamed_pt_key .replace ("net.0.proj" , "net_0" )
264- # FORCE 'weight' -> 'kernel' for this dense layer
265- renamed_pt_key = renamed_pt_key .replace ("weight " , "kernel " )
266-
267- # 2. FIX net_2 (Source has NO '.proj', Target likely NO '.proj' for symmetry )
268- # Source: ...ff.net.2.weight -> Target: ...ff.net_2.kernel
269- elif "net.2" in renamed_pt_key :
264+ elif "net_0.proj" in renamed_pt_key : # <--- THIS IS THE MISSING LINK
265+ renamed_pt_key = renamed_pt_key .replace ("net_0.proj " , "net_0 " )
266+
267+ # 2. FIX net_2: Ensure consistent naming (net.2 -> net_2 )
268+ # JAX wants 'net_2', source is 'net.2'.
269+ if "net.2" in renamed_pt_key :
270270 renamed_pt_key = renamed_pt_key .replace ("net.2" , "net_2" )
271- # FORCE 'weight' -> 'kernel' for this dense layer
272- renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
273-
274- # 3. FIX Norm1 Nesting
275- # Source: ...norm1.weight -> Target: ...norm1.layer_norm.scale
271+
272+ # 3. FIX Norm1: Add .layer_norm wrapper
276273 renamed_pt_key = renamed_pt_key .replace ("norm1" , "norm1.layer_norm" )
277274
278- # 4. FIX Norm Parameter Names (Scale vs Weight)
275+ # 4. FIX Parameter Names (Weight -> Kernel/Scale)
276+ # Force 'scale' for all norms
279277 if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key :
280278 renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
281- # Handle case where rename_key might have already turned it into kernel
282279 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
280+
281+ # Force 'kernel' for dense layers (net_0 and net_2)
282+ # We check 'net_0' because we just renamed it above.
283+ elif "net_0" in renamed_pt_key or "net_2" in renamed_pt_key :
284+ renamed_pt_key = renamed_pt_key .replace ("weight" , "kernel" )
285+
286+ # 5. Global Norm Fix (e.g. norm_added_q)
287+ if "norm" in renamed_pt_key and "image_embedder" not in renamed_pt_key :
288+ if "norm_added" in renamed_pt_key or "norm_k" in renamed_pt_key or "norm_q" in renamed_pt_key :
289+ renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
290+ renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
283291 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
284292 renamed_pt_key = renamed_pt_key .replace (".scale_shift_table" , ".adaln_scale_shift_table" )
285293 renamed_pt_key = renamed_pt_key .replace ("to_out_0" , "proj_attn" )
0 commit comments