Skip to content

Commit 635f6fc

Browse files
committed
missing key debug
1 parent f39796f commit 635f6fc

1 file changed

Lines changed: 24 additions & 10 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -256,19 +256,25 @@ def load_base_wan_transformer(
256256
norm_added_q_buffer = {}
257257
print(f"DEBUG: Total keys found in checkpoint: {len(tensors)}")
258258
for pt_key, tensor in tensors.items():
259-
if "norm_added_q" in pt_key:
260-
print(f"DEBUG: Found norm_added_q key: {pt_key}")
261-
renamed_pt_key = rename_key(pt_key)
262-
if "norm_added_q" in pt_key:
259+
if "norm_added_q" in pt_key and "weight" in pt_key:
263260
parts = pt_key.split(".")
264261
try:
265-
block_idx = int(parts[1])
262+
# Robustly find the block index (handles 'blocks.0' and 'model...blocks.0')
263+
if "blocks" in parts:
264+
block_idx_loc = parts.index("blocks") + 1
265+
block_idx = int(parts[block_idx_loc])
266+
267+
# Transpose and buffer
268+
tensor = tensor.T
269+
norm_added_q_buffer[block_idx] = tensor
270+
else:
271+
print(f"Warning: skipped {pt_key} (no 'blocks' found)")
266272
except ValueError:
267-
print(f"DEBUG: Failed to parse index from {pt_key}")
268-
continue
269-
tensor = tensor.T
270-
norm_added_q_buffer[block_idx] = tensor
273+
print(f"Warning: skipped {pt_key} (index parse error)")
274+
275+
# Skip the standard processing for these keys
271276
continue
277+
renamed_pt_key = rename_key(pt_key)
272278
if "image_embedder" in renamed_pt_key:
273279
if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
274280
"net.2" in renamed_pt_key or "net_2" in renamed_pt_key:
@@ -299,10 +305,18 @@ def load_base_wan_transformer(
299305
flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
300306
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
301307
if norm_added_q_buffer:
302-
sorted_tensors = [norm_added_q_buffer[i] for i in sorted(norm_added_q_buffer.keys())]
308+
# Sort by block index to ensure correct layer order
309+
sorted_keys = sorted(norm_added_q_buffer.keys())
310+
sorted_tensors = [norm_added_q_buffer[i] for i in sorted_keys]
311+
312+
# Stack into shape (40, ...)
303313
stacked_tensor = jnp.stack(sorted_tensors, axis=0)
314+
315+
# KEY FIX: The error demanded 'kernel', so we hardcode that key here.
304316
final_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
317+
305318
flax_state_dict[final_key] = jax.device_put(stacked_tensor, device=cpu)
319+
print(f"Successfully injected {final_key} with shape {stacked_tensor.shape}")
306320

307321
validate_flax_state_dict(eval_shapes, flax_state_dict)
308322
flax_state_dict = unflatten_dict(flax_state_dict)

0 commit comments

Comments
 (0)