Skip to content

Commit 765f4bf

Browse files
committed
missing key debug
1 parent 0b4e348 commit 765f4bf

1 file changed

Lines changed: 20 additions & 13 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 20 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -304,22 +304,29 @@ def load_base_wan_transformer(
304304
stacked_tensor = jnp.stack(sorted_tensors, axis=0)
305305

306306
target_key = None
307+
print("DEBUG: Searching eval_shapes for norm_added_q...")
308+
possible_keys = []
307309

308310
for key_tuple in flattened_dict.keys():
309-
# Check if this tuple looks like what we want
310-
# We check if it ends with 'norm_added_q' and 'kernel'
311-
if len(key_tuple) >= 2 and key_tuple[-2:] == ('norm_added_q', 'kernel'):
312-
target_key = key_tuple
313-
break
314-
if target_key:
315-
print(f"DEBUG: Found authoritative key in eval_shapes: {target_key}")
316-
flax_state_dict[target_key] = jax.device_put(stacked_tensor, device=cpu)
317-
print(f"Successfully injected norm_added_q with shape {stacked_tensor.shape}")
311+
if "norm_added_q" in key_tuple:
312+
possible_keys.append(key_tuple)
313+
314+
if len(possible_keys) > 0:
315+
# Pick the first one (should only be one for this specific layer)
316+
target_key = possible_keys[0]
317+
print(f"DEBUG: Found matching key in eval_shapes: {target_key}")
318+
flax_state_dict[target_key] = jax.device_put(stacked_tensor, device=cpu)
318319
else:
319-
# Fallback (should typically not happen if error message was correct)
320-
print("CRITICAL WARNING: Could not find norm_added_q key in eval_shapes! Using manual fallback.")
321-
manual_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
322-
flax_state_dict[manual_key] = jax.device_put(stacked_tensor, device=cpu)
320+
# If we still find nothing, print a sample (the first 20) of the keys to help the user debug
321+
print("CRITICAL ERROR: 'norm_added_q' NOT FOUND in eval_shapes.")
322+
print("DEBUG: Dumping sample keys from eval_shapes to help debug:")
323+
for i, k in enumerate(list(flattened_dict.keys())[:20]):
324+
print(f" {k}")
325+
326+
# Last resort fallback
327+
manual_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
328+
print(f"DEBUG: Attempting manual injection to {manual_key}")
329+
flax_state_dict[manual_key] = jax.device_put(stacked_tensor, device=cpu)
323330

324331
del flattened_dict
325332

0 commit comments

Comments
 (0)