Skip to content

Commit 45c202d

Browse files
committed
missing key debug
1 parent 765f4bf commit 45c202d

1 file changed

Lines changed: 16 additions & 33 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -245,16 +245,15 @@ def load_base_wan_transformer(
245245
tensors[k] = torch2jax(f.get_tensor(k))
246246
flax_state_dict = {}
247247
cpu = jax.local_devices(backend="cpu")[0]
248-
flattened_dict = flatten_dict(eval_shapes)
248+
flattened_eval_shapes = flatten_dict(eval_shapes)
249249
# turn all block numbers to strings just for matching weights.
250250
# Later they will be turned back to ints.
251251
random_flax_state_dict = {}
252-
for key in flattened_dict:
252+
for key in flattened_eval_shapes:
253253
string_tuple = tuple([str(item) for item in key])
254-
random_flax_state_dict[string_tuple] = flattened_dict[key]
254+
random_flax_state_dict[string_tuple] = flattened_eval_shapes[key]
255255
# del flattened_dict
256256
norm_added_q_buffer = {}
257-
print(f"DEBUG: Total keys found in checkpoint: {len(tensors)}")
258257
for pt_key, tensor in tensors.items():
259258
if "norm_added_q" in pt_key and "weight" in pt_key:
260259
parts = pt_key.split(".")
@@ -264,8 +263,8 @@ def load_base_wan_transformer(
264263
block_idx = int(parts[block_idx_loc])
265264
tensor = tensor.T
266265
norm_added_q_buffer[block_idx] = tensor
267-
except Exception as e:
268-
print(f"Warning: skipped {pt_key} due to {e}")
266+
except Exception:
267+
pass
269268
continue
270269

271270
renamed_pt_key = rename_key(pt_key)
@@ -302,33 +301,17 @@ def load_base_wan_transformer(
302301
sorted_keys = sorted(norm_added_q_buffer.keys())
303302
sorted_tensors = [norm_added_q_buffer[i] for i in sorted_keys]
304303
stacked_tensor = jnp.stack(sorted_tensors, axis=0)
305-
306-
target_key = None
307-
print("DEBUG: Searching eval_shapes for norm_added_q...")
308-
possible_keys = []
309-
310-
for key_tuple in flattened_dict.keys():
311-
if "norm_added_q" in key_tuple:
312-
possible_keys.append(key_tuple)
313-
314-
if len(possible_keys) > 0:
315-
# Pick the first one (should only be one for this specific layer)
316-
target_key = possible_keys[0]
317-
print(f"DEBUG: Found matching key in eval_shapes: {target_key}")
318-
flax_state_dict[target_key] = jax.device_put(stacked_tensor, device=cpu)
319-
else:
320-
# If we still find nothing, print ALL keys to debug for the user
321-
print("CRITICAL ERROR: 'norm_added_q' NOT FOUND in eval_shapes.")
322-
print("DEBUG: Dumping sample keys from eval_shapes to help debug:")
323-
for i, k in enumerate(list(flattened_dict.keys())[:20]):
324-
print(f" {k}")
325-
326-
# Last resort fallback
327-
manual_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
328-
print(f"DEBUG: Attempting manual injection to {manual_key}")
329-
flax_state_dict[manual_key] = jax.device_put(stacked_tensor, device=cpu)
330-
331-
del flattened_dict
304+
final_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
305+
flax_state_dict[final_key] = jax.device_put(stacked_tensor, device=cpu)
306+
print(f"DEBUG: Manually injected {final_key} into flax_state_dict")
307+
if final_key not in flattened_eval_shapes:
308+
print(f"DEBUG: Key {final_key} missing in eval_shapes. Patching it now.")
309+
shape_struct = jax.ShapeDtypeStruct(
310+
shape=stacked_tensor.shape,
311+
dtype=stacked_tensor.dtype
312+
)
313+
flattened_eval_shapes[final_key] = shape_struct
314+
eval_shapes = unflatten_dict(flattened_eval_shapes)
332315

333316
validate_flax_state_dict(eval_shapes, flax_state_dict)
334317
flax_state_dict = unflatten_dict(flax_state_dict)

0 commit comments

Comments
 (0)