We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f05c97b commit f39796fCopy full SHA for f39796f
1 file changed
src/maxdiffusion/models/wan/wan_utils.py
@@ -254,11 +254,18 @@ def load_base_wan_transformer(
254
random_flax_state_dict[string_tuple] = flattened_dict[key]
255
del flattened_dict
256
norm_added_q_buffer = {}
257
+ print(f"DEBUG: Total keys found in checkpoint: {len(tensors)}")
258
for pt_key, tensor in tensors.items():
259
+ if "norm_added_q" in pt_key:
260
+ print(f"DEBUG: Found norm_added_q key: {pt_key}")
261
renamed_pt_key = rename_key(pt_key)
262
if "norm_added_q" in pt_key:
263
parts = pt_key.split(".")
- block_idx = int(parts[1])
264
+ try:
265
+ block_idx = int(parts[1])
266
+ except ValueError:
267
+ print(f"DEBUG: Failed to parse index from {pt_key}")
268
+ continue
269
tensor = tensor.T
270
norm_added_q_buffer[block_idx] = tensor
271
continue
0 commit comments