Skip to content

Commit 0558ec1

Browse files
committed
missing keys error
1 parent 96bf827 commit 0558ec1

1 file changed

Lines changed: 15 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -255,6 +255,8 @@ def load_base_wan_transformer(
255255
del flattened_dict
256256
for pt_key, tensor in tensors.items():
257257
renamed_pt_key = rename_key(pt_key)
258+
if "norm_added_q" in pt_key:
259+
debug_original = renamed_pt_key
258260
if "image_embedder" in renamed_pt_key:
259261
if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
260262
"net.2" in renamed_pt_key or "net_2" in renamed_pt_key:
@@ -276,16 +278,27 @@ def load_base_wan_transformer(
276278
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
277279

278280
if "norm_added_q" in renamed_pt_key:
279-
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
280-
tensor = tensor.T
281+
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
282+
tensor = tensor.T
281283
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
284+
282285

283286
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
284287
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
285288
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
286289
renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
287290
renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
288291
renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
292+
if "norm_added_q" in pt_key:
293+
print(f"DEBUG REPORT for {pt_key}:")
294+
print(f" 1. After rename_key : {debug_original}")
295+
print(f" 2. Final Key String : {renamed_pt_key}")
296+
297+
# Test parsing
298+
pt_tuple_key = tuple(renamed_pt_key.split("."))
299+
flax_key, _ = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
300+
print(f" 3. Parsed Flax Key : {flax_key}")
301+
print("-" * 20)
289302
pt_tuple_key = tuple(renamed_pt_key.split("."))
290303
flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
291304
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)

0 commit comments

Comments
 (0)