File tree Expand file tree Collapse file tree
src/maxdiffusion/models/wan Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -275,9 +275,11 @@ def load_base_wan_transformer(
275275 renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
276276 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
277277
278- if "norm_added" in renamed_pt_key :
279- if "attn2_norm_added" in renamed_pt_key :
280- renamed_pt_key = renamed_pt_key .replace ("attn2_norm_added" , "attn2.norm_added" )
278+ if "norm_added_q" in renamed_pt_key :
279+ # 1. Restore the dot before norm_added_q so JAX sees it as a submodule
280+ renamed_pt_key = renamed_pt_key .replace ("attn2_norm_added_q" , "attn2.norm_added_q" )
281+
282+ # 2. Force 'weight' -> 'scale' (JAX requirement)
281283 renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
282284 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
283285 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
You can’t perform that action at this time.
0 commit comments