
Commit cae76f0

missing key debug
1 parent 45c202d commit cae76f0

2 files changed: 33 additions & 23 deletions


src/maxdiffusion/models/attention_flax.py (11 additions & 1 deletion)
@@ -885,6 +885,7 @@ def __init__(
     self.add_k_proj = nnx.data(None)
     self.add_v_proj = nnx.data(None)
     self.norm_added_k = nnx.data(None)
+    self.norm_added_q = nnx.data(None)
     if self.added_kv_proj_dim is not None:
       self.add_k_proj = nnx.Linear(
           self.added_kv_proj_dim, self.inner_dim, rngs=rngs,
@@ -909,6 +910,13 @@ def __init__(
               ("norm",),
           ),
       )
+      self.norm_added_q = nnx.RMSNorm(
+          num_features=self.inner_dim, rngs=rngs, epsilon=eps, dtype=dtype, param_dtype=weights_dtype,
+          scale_init=nnx.with_partitioning(
+              nnx.initializers.ones,
+              ("norm",),
+          ),
+      )

   def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]:
     dtype = xq.dtype
@@ -1016,8 +1024,10 @@ def __call__(
     # Attention - tensors are (B, S, D)
     with self.conditional_named_scope("cross_attn_text_apply"):
       attn_output_text = self.attention_op.apply_attention(query_proj, key_proj_text, value_proj_text)
+    with self.conditional_named_scope("norm_added_q"):
+      query_proj_img = self.norm_added_q(query_proj)
     with self.conditional_named_scope("cross_attn_img_apply"):
-      attn_output_img = self.attention_op.apply_attention(query_proj, key_proj_img, value_proj_img)
+      attn_output_img = self.attention_op.apply_attention(query_proj_img, key_proj_img, value_proj_img)

     attn_output = attn_output_text + attn_output_img

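For orientation: this change normalizes the shared query projection with the newly registered norm_added_q RMSNorm, but only for the image cross-attention branch; the text branch still attends with the raw query_proj. A minimal standalone sketch of that path (batch/sequence sizes and epsilon are made up; the real module also threads dtype, param_dtype, and sharding as shown in the diff):

import jax.numpy as jnp
from flax import nnx

inner_dim = 128
norm_added_q = nnx.RMSNorm(num_features=inner_dim, epsilon=1e-6, rngs=nnx.Rngs(0))

query_proj = jnp.ones((2, 16, inner_dim))  # (B, S, D)
query_proj_img = norm_added_q(query_proj)  # normalized query, image branch only
# attn_output_img would be computed from query_proj_img; the text branch
# keeps using the un-normalized query_proj.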
src/maxdiffusion/models/wan/wan_utils.py (22 additions & 22 deletions)
@@ -215,14 +215,12 @@ def load_base_wan_transformer(
       raise FileNotFoundError(f"File {index_file_path} not found for local directory.")
     local_files = True
   elif hf_download:
-    # download the index file for sharded models.
     index_file_path = hf_hub_download(
         pretrained_model_name_or_path,
         subfolder=subfolder,
         filename=filename,
     )
   with jax.default_device(device):
-    # open the index file.
     with open(index_file_path, "r") as f:
       index_dict = json.load(f)
     model_files = set()
@@ -236,37 +234,42 @@ def load_base_wan_transformer(
         ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
       else:
         ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
-      # now get all the filenames for the model that need downloading
+
       max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")

       if ckpt_shard_path is not None:
         with safe_open(ckpt_shard_path, framework="pt") as f:
           for k in f.keys():
             tensors[k] = torch2jax(f.get_tensor(k))
+
     flax_state_dict = {}
     cpu = jax.local_devices(backend="cpu")[0]
-    flattened_eval_shapes = flatten_dict(eval_shapes)
-    # turn all block numbers to strings just for matching weights.
-    # Later they will be turned back to ints.
+    flattened_dict = flatten_dict(eval_shapes)
     random_flax_state_dict = {}
-    for key in flattened_eval_shapes:
+    for key in flattened_dict:
       string_tuple = tuple([str(item) for item in key])
-      random_flax_state_dict[string_tuple] = flattened_eval_shapes[key]
-    # del flattened_dict
+      random_flax_state_dict[string_tuple] = flattened_dict[key]
+    del flattened_dict
+
+    # 1. Initialize buffer for norm_added_q
     norm_added_q_buffer = {}
+
     for pt_key, tensor in tensors.items():
+      # 2. Robustly Intercept norm_added_q keys
       if "norm_added_q" in pt_key and "weight" in pt_key:
         parts = pt_key.split(".")
         try:
+          # Find the block index regardless of prefix (blocks.0 vs model.blocks.0)
           if "blocks" in parts:
             block_idx_loc = parts.index("blocks") + 1
             block_idx = int(parts[block_idx_loc])
             tensor = tensor.T
             norm_added_q_buffer[block_idx] = tensor
         except Exception:
-          pass
+          pass  # Skip if unparseable
         continue
-
+
+      # Standard processing
       renamed_pt_key = rename_key(pt_key)
       if "image_embedder" in renamed_pt_key:
         if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
@@ -287,39 +290,36 @@ def load_base_wan_transformer(
       if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
         renamed_pt_key = renamed_pt_key.replace("weight", "scale")
         renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
+
       renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
       renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
       renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
       renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
       renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
       if "norm2.layer_norm" not in renamed_pt_key:
         renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
+
       pt_tuple_key = tuple(renamed_pt_key.split("."))
       flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+
+    # 3. Stack and Insert (Correct Key Name for RMSNorm is 'scale')
     if norm_added_q_buffer:
       sorted_keys = sorted(norm_added_q_buffer.keys())
       sorted_tensors = [norm_added_q_buffer[i] for i in sorted_keys]
       stacked_tensor = jnp.stack(sorted_tensors, axis=0)
-      final_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
+
+      # 'scale' is the correct parameter name for Flax/NNX RMSNorm
+      final_key = ('blocks', 'attn2', 'norm_added_q', 'scale')
+
       flax_state_dict[final_key] = jax.device_put(stacked_tensor, device=cpu)
-      print(f"DEBUG: Manually injected {final_key} into flax_state_dict")
-      if final_key not in flattened_eval_shapes:
-        print(f"DEBUG: Key {final_key} missing in eval_shapes. Patching it now.")
-        shape_struct = jax.ShapeDtypeStruct(
-            shape=stacked_tensor.shape,
-            dtype=stacked_tensor.dtype
-        )
-        flattened_eval_shapes[final_key] = shape_struct
-        eval_shapes = unflatten_dict(flattened_eval_shapes)

     validate_flax_state_dict(eval_shapes, flax_state_dict)
     flax_state_dict = unflatten_dict(flax_state_dict)
     del tensors
     jax.clear_caches()
     return flax_state_dict

-
 def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True):
   device = jax.devices(device)[0]
   subfolder = "vae"

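The stack-and-insert step reflects how scanned transformer blocks store parameters: one tensor per parameter with a leading block axis, which is why the buffered per-block tensors are sorted by block index before being stacked along axis 0. A toy version of the pattern (buffer contents and shapes are invented):

import jax.numpy as jnp

norm_added_q_buffer = {1: jnp.full((4,), 1.0), 0: jnp.full((4,), 0.5)}  # block_idx -> per-block scale

sorted_tensors = [norm_added_q_buffer[i] for i in sorted(norm_added_q_buffer.keys())]
stacked_tensor = jnp.stack(sorted_tensors, axis=0)  # leading axis indexes blocks
print(stacked_tensor.shape)  # (2, 4)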