Skip to content

Commit cf139a0

Browse files
committed
missing keys error
1 parent 0558ec1 commit cf139a0

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -253,8 +253,15 @@ def load_base_wan_transformer(
253253
string_tuple = tuple([str(item) for item in key])
254254
random_flax_state_dict[string_tuple] = flattened_dict[key]
255255
del flattened_dict
256+
norm_added_q_buffer = {}
256257
for pt_key, tensor in tensors.items():
257258
renamed_pt_key = rename_key(pt_key)
259+
if "norm_added_q" in pt_key:
260+
parts = pt_key.split(".")
261+
block_idx = int(parts[1])
262+
tensor = tensor.T
263+
norm_added_q_buffer[block_idx] = tensor
264+
continue
258265
if "norm_added_q" in pt_key:
259266
debug_original = renamed_pt_key
260267
if "image_embedder" in renamed_pt_key:
@@ -276,13 +283,6 @@ def load_base_wan_transformer(
276283
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
277284
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
278285
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
279-
280-
if "norm_added_q" in renamed_pt_key:
281-
renamed_pt_key = renamed_pt_key.replace("weight", "kernel")
282-
tensor = tensor.T
283-
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
284-
285-
286286
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
287287
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
288288
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
@@ -302,6 +302,11 @@ def load_base_wan_transformer(
302302
pt_tuple_key = tuple(renamed_pt_key.split("."))
303303
flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
304304
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
305+
if norm_added_q_buffer:
306+
sorted_tensors = [norm_added_q_buffer[i] for i in sorted(norm_added_q_buffer.keys())]
307+
stacked_tensor = jnp.stack(sorted_tensors, axis=0)
308+
final_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
309+
flax_state_dict[final_key] = jax.device_put(stacked_tensor, device=cpu)
305310

306311
validate_flax_state_dict(eval_shapes, flax_state_dict)
307312
flax_state_dict = unflatten_dict(flax_state_dict)

0 commit comments

Comments
 (0)