
Commit 06cb1ae
missing key debug
1 parent cae76f0

1 file changed: src/maxdiffusion/models/wan/wan_utils.py (6 additions, 9 deletions)
@@ -215,12 +215,14 @@ def load_base_wan_transformer(
       raise FileNotFoundError(f"File {index_file_path} not found for local directory.")
     local_files = True
   elif hf_download:
+    # download the index file for sharded models.
     index_file_path = hf_hub_download(
         pretrained_model_name_or_path,
         subfolder=subfolder,
         filename=filename,
     )
   with jax.default_device(device):
+    # open the index file.
     with open(index_file_path, "r") as f:
       index_dict = json.load(f)
     model_files = set()
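
Note: the index file fetched here follows the standard Hugging Face sharded-checkpoint layout: a JSON object whose "weight_map" maps every parameter name to the shard file that stores it. A minimal sketch of pulling the unique shard names out of such an index (the file name below is illustrative; the real one arrives via the filename argument above):

import json

# Illustrative index name; the real path comes from hf_hub_download above.
with open("diffusion_pytorch_model.safetensors.index.json", "r") as f:
  index_dict = json.load(f)

# A weight_map entry looks like:
#   "blocks.0.attn2.norm_added_q.weight": "model-00001-of-00007.safetensors"
model_files = set(index_dict["weight_map"].values())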
@@ -234,7 +236,7 @@ def load_base_wan_transformer(
       ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
     else:
       ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
-
+    # now get all the filenames for the model that need downloading
     max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")

     if ckpt_shard_path is not None:
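
Note: the two branches above differ only in where the shard lives: os.path.join builds a path into an already-local checkout, while hf_hub_download fetches the file into the local Hub cache and returns the cached path. A hedged sketch of fetching one shard and reading it lazily, not the repo's actual loop (the repo id and shard name are placeholders):

from huggingface_hub import hf_hub_download
from safetensors import safe_open

# Placeholder repo id and shard name, for illustration only.
ckpt_shard_path = hf_hub_download(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    subfolder="transformer",
    filename="diffusion_pytorch_model.safetensors",
)

# safe_open memory-maps the shard; tensors are read lazily by name.
with safe_open(ckpt_shard_path, framework="np") as f:
  for name in f.keys():
    tensor = f.get_tensor(name)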
@@ -245,31 +247,29 @@ def load_base_wan_transformer(
     flax_state_dict = {}
     cpu = jax.local_devices(backend="cpu")[0]
     flattened_dict = flatten_dict(eval_shapes)
+    # turn all block numbers to strings just for matching weights.
+    # Later they will be turned back to ints.
     random_flax_state_dict = {}
     for key in flattened_dict:
       string_tuple = tuple([str(item) for item in key])
       random_flax_state_dict[string_tuple] = flattened_dict[key]
     del flattened_dict

-    # 1. Initialize buffer for norm_added_q
     norm_added_q_buffer = {}

     for pt_key, tensor in tensors.items():
-      # 2. Robustly Intercept norm_added_q keys
       if "norm_added_q" in pt_key and "weight" in pt_key:
         parts = pt_key.split(".")
         try:
-          # Find the block index regardless of prefix (blocks.0 vs model.blocks.0)
           if "blocks" in parts:
             block_idx_loc = parts.index("blocks") + 1
             block_idx = int(parts[block_idx_loc])
             tensor = tensor.T
             norm_added_q_buffer[block_idx] = tensor
         except Exception:
-          pass  # Skip if unparseable
+          pass
         continue

-      # Standard processing
       renamed_pt_key = rename_key(pt_key)
       if "image_embedder" in renamed_pt_key:
         if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
@@ -303,13 +303,10 @@ def load_base_wan_transformer(
       flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)

-    # 3. Stack and Insert (Correct Key Name for RMSNorm is 'scale')
     if norm_added_q_buffer:
       sorted_keys = sorted(norm_added_q_buffer.keys())
       sorted_tensors = [norm_added_q_buffer[i] for i in sorted_keys]
       stacked_tensor = jnp.stack(sorted_tensors, axis=0)
-
-      # 'scale' is the correct parameter name for Flax/NNX RMSNorm
       final_key = ('blocks', 'attn2', 'norm_added_q', 'scale')

       flax_state_dict[final_key] = jax.device_put(stacked_tensor, device=cpu)
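
Note: this final hunk stacks one tensor per block along a new leading axis, the layout scanned Flax layers expect (axis 0 indexes the layer), and stores the result under 'scale', the parameter name Flax RMSNorm uses for its learned gain (as the removed comments noted). A toy shape check, with made-up dimensions:

import jax.numpy as jnp

head_dim = 128
buffer = {i: jnp.ones((head_dim,)) for i in range(4)}  # one vector per block

# Sort by block index so axis 0 lines up with layer order.
stacked = jnp.stack([buffer[i] for i in sorted(buffer)], axis=0)
print(stacked.shape)  # (4, 128)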
