File tree Expand file tree Collapse file tree
src/maxdiffusion/models/wan Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -276,12 +276,17 @@ def load_base_wan_transformer(
276276 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
277277
278278 if "norm_added_q" in renamed_pt_key :
279- # 1. Restore the hierarchy: Convert flattened "attn2_norm_added_q" back to "attn2.norm_added_q"
280- # We use a broad replace to catch it regardless of what prefix (blocks_X_) is before it.
279+ # 1. Structural Fix: Ensure 'attn2' is separated from the block index
280+ # 'blocks_0_attn2' -> 'blocks_0.attn2'
281+ if "_attn2" in renamed_pt_key :
282+ renamed_pt_key = renamed_pt_key .replace ("_attn2" , ".attn2" )
283+
284+ # 2. Restore hierarchy for the norm itself
285+ # 'attn2_norm_added_q' -> 'attn2.norm_added_q'
281286 if "attn2_norm_added_q" in renamed_pt_key :
282287 renamed_pt_key = renamed_pt_key .replace ("attn2_norm_added_q" , "attn2.norm_added_q" )
283288
284- # 2 . Force 'weight' -> 'scale' (Crucial: JAX explicitly asks for 'scale')
289+ # 3 . Force 'weight' -> 'scale'
285290 renamed_pt_key = renamed_pt_key .replace ("weight" , "scale" )
286291 renamed_pt_key = renamed_pt_key .replace ("kernel" , "scale" )
287292 renamed_pt_key = renamed_pt_key .replace ("blocks_" , "blocks." )
You can’t perform that action at this time.
0 commit comments