File tree Expand file tree Collapse file tree
src/maxdiffusion/models/wan Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -276,10 +276,12 @@ def load_base_wan_transformer(
276276 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
277277
278278 if "norm_added_q" in renamed_pt_key :
279- # 1. Restore the dot before norm_added_q so JAX sees it as a submodule
280- renamed_pt_key = renamed_pt_key .replace ("attn2_norm_added_q" , "attn2.norm_added_q" )
279+ # 1. Restore the hierarchy: Convert flattened "attn2_norm_added_q" back to "attn2.norm_added_q"
280+ # We use a broad replace to catch it regardless of what prefix (blocks_X_) is before it.
281+ if "attn2_norm_added_q" in renamed_pt_key :
282+ renamed_pt_key = renamed_pt_key .replace ("attn2_norm_added_q" , "attn2.norm_added_q" )
281283
282- # 2. Force 'weight' -> 'scale' (JAX requirement )
284+ # 2. Force 'weight' -> 'scale' (Crucial: JAX explicitly asks for 'scale' )
283285 renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
284286 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
285287 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
You can’t perform that action at this time.
0 commit comments