Skip to content

Commit 0b4e348

Browse files
committed
missing key debug
1 parent 635f6fc commit 0b4e348

1 file changed

Lines changed: 23 additions & 18 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 23 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -252,28 +252,22 @@ def load_base_wan_transformer(
252252
for key in flattened_dict:
253253
string_tuple = tuple([str(item) for item in key])
254254
random_flax_state_dict[string_tuple] = flattened_dict[key]
255-
del flattened_dict
255+
# del flattened_dict
256256
norm_added_q_buffer = {}
257257
print(f"DEBUG: Total keys found in checkpoint: {len(tensors)}")
258258
for pt_key, tensor in tensors.items():
259259
if "norm_added_q" in pt_key and "weight" in pt_key:
260260
parts = pt_key.split(".")
261261
try:
262-
# Robustly find the block index (handles 'blocks.0' and 'model...blocks.0')
263262
if "blocks" in parts:
264263
block_idx_loc = parts.index("blocks") + 1
265264
block_idx = int(parts[block_idx_loc])
266-
267-
# Transpose and buffer
268265
tensor = tensor.T
269266
norm_added_q_buffer[block_idx] = tensor
270-
else:
271-
print(f"Warning: skipped {pt_key} (no 'blocks' found)")
272-
except ValueError:
273-
print(f"Warning: skipped {pt_key} (index parse error)")
274-
275-
# Skip the standard processing for these keys
267+
except Exception as e:
268+
print(f"Warning: skipped {pt_key} due to {e}")
276269
continue
270+
277271
renamed_pt_key = rename_key(pt_key)
278272
if "image_embedder" in renamed_pt_key:
279273
if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
@@ -305,18 +299,29 @@ def load_base_wan_transformer(
305299
flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
306300
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
307301
if norm_added_q_buffer:
308-
# Sort by block index to ensure correct layer order
309302
sorted_keys = sorted(norm_added_q_buffer.keys())
310303
sorted_tensors = [norm_added_q_buffer[i] for i in sorted_keys]
311-
312-
# Stack into shape (40, ...)
313304
stacked_tensor = jnp.stack(sorted_tensors, axis=0)
305+
306+
target_key = None
307+
308+
for key_tuple in flattened_dict.keys():
309+
# Match a key tuple whose last two components are
310+
# ('norm_added_q', 'kernel'); the first match found is used.
311+
if len(key_tuple) >= 2 and key_tuple[-2:] == ('norm_added_q', 'kernel'):
312+
target_key = key_tuple
313+
break
314+
if target_key:
315+
print(f"DEBUG: Found authoritative key in eval_shapes: {target_key}")
316+
flax_state_dict[target_key] = jax.device_put(stacked_tensor, device=cpu)
317+
print(f"Successfully injected norm_added_q with shape {stacked_tensor.shape}")
318+
else:
319+
# Fallback (not expected in normal operation): no key ending in ('norm_added_q', 'kernel') was found, so use a hard-coded key path.
320+
print("CRITICAL WARNING: Could not find norm_added_q key in eval_shapes! Using manual fallback.")
321+
manual_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
322+
flax_state_dict[manual_key] = jax.device_put(stacked_tensor, device=cpu)
314323

315-
# KEY FIX: The error demanded 'kernel', so we hardcode that key here.
316-
final_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
317-
318-
flax_state_dict[final_key] = jax.device_put(stacked_tensor, device=cpu)
319-
print(f"Successfully injected {final_key} with shape {stacked_tensor.shape}")
324+
del flattened_dict
320325

321326
validate_flax_state_dict(eval_shapes, flax_state_dict)
322327
flax_state_dict = unflatten_dict(flax_state_dict)

0 commit comments

Comments
 (0)