Skip to content

Commit 4c954a5

Browse files
committed
missing keys error
1 parent cf139a0 commit 4c954a5

1 file changed

Lines changed: 2 additions & 13 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,6 @@ def load_base_wan_transformer(
262262
tensor = tensor.T
263263
norm_added_q_buffer[block_idx] = tensor
264264
continue
265-
if "norm_added_q" in pt_key:
266-
debug_original = renamed_pt_key
267265
if "image_embedder" in renamed_pt_key:
268266
if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
269267
"net.2" in renamed_pt_key or "net_2" in renamed_pt_key:
@@ -288,17 +286,8 @@ def load_base_wan_transformer(
288286
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
289287
renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
290288
renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
291-
renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
292-
if "norm_added_q" in pt_key:
293-
print(f"DEBUG REPORT for {pt_key}:")
294-
print(f" 1. After rename_key : {debug_original}")
295-
print(f" 2. Final Key String : {renamed_pt_key}")
296-
297-
# Test parsing
298-
pt_tuple_key = tuple(renamed_pt_key.split("."))
299-
flax_key, _ = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
300-
print(f" 3. Parsed Flax Key : {flax_key}")
301-
print("-" * 20)
289+
if "norm2.layer_norm" not in renamed_pt_key:
290+
renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
302291
pt_tuple_key = tuple(renamed_pt_key.split("."))
303292
flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
304293
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)

0 commit comments

Comments
 (0)