
Commit d4e22a4

Fix
1 parent 95c985b commit d4e22a4

2 files changed: 60 additions & 57 deletions

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 45 additions & 41 deletions
@@ -612,16 +612,11 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
 
 def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
   """
-  Translates WAN NNX path like 'blocks.10.attn1.key' (scan_layers=False) or
-  'blocks.attn1.key' (scan_layers=True) to
-  LoRA path like 'diffusion_model.blocks.10.self_attn.k' or
-  template 'diffusion_model.blocks.{}.self_attn.k'.
-  Returns None if no match.
+  Translates a WAN NNX path to a Diffusers/LoRA key.
+  Verified against wan_utils.py mappings.
   """
 
-  # Handle embeddings - exact paths
-  if nnx_path_str == "patch_embedding":
-    return "diffusion_model.patch_embedding"
+  # --- 1. Embeddings (Exact Matches) ---
   if nnx_path_str == 'condition_embedder.text_embedder.linear_1':
     return 'diffusion_model.text_embedding.0'
   if nnx_path_str == 'condition_embedder.text_embedder.linear_2':
@@ -630,46 +625,55 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
     return 'diffusion_model.time_embedding.0'
   if nnx_path_str == 'condition_embedder.time_embedder.linear_2':
     return 'diffusion_model.time_embedding.2'
-
-  # Translation for Attention and FFN layers
-  attn_ffn_map = {
-      "attn1.query": "self_attn.q",
-      "attn1.key": "self_attn.k",
-      "attn1.value": "self_attn.v",
+  if nnx_path_str == 'patch_embedding':
+    return 'diffusion_model.patch_embedding'
+
+  # --- 2. Map NNX Suffixes to LoRA Suffixes ---
+  suffix_map = {
+      # Self Attention (attn1)
+      "attn1.query": "self_attn.q",
+      "attn1.key": "self_attn.k",
+      "attn1.value": "self_attn.v",
       "attn1.proj_attn": "self_attn.o",
-      "attn2.query": "cross_attn.q",
-      "attn2.key": "cross_attn.k",
-      "attn2.value": "cross_attn.v",
+
+      # Self Attention Norms (QK Norm)
+      "attn1.norm_q": "self_attn.norm_q",
+      "attn1.norm_k": "self_attn.norm_k",
+
+      # Cross Attention (attn2)
+      "attn2.query": "cross_attn.q",
+      "attn2.key": "cross_attn.k",
+      "attn2.value": "cross_attn.v",
       "attn2.proj_attn": "cross_attn.o",
-      "ffn.act_fn.proj": "ffn.0",
-      "ffn.proj_out": "ffn.2",
-  }
-  # Translation for Norm layers
-  norm_map = {
-      "norm3": "norm3",
-      "attn1.norm_q": "self_attn.norm_q",
-      "attn1.norm_k": "self_attn.norm_k",
-      "attn2.norm_q": "cross_attn.norm_q",
-      "attn2.norm_k": "cross_attn.norm_k",
+
+      # Cross Attention Norms (QK Norm)
+      "attn2.norm_q": "cross_attn.norm_q",
+      "attn2.norm_k": "cross_attn.norm_k",
+
+      # Feed Forward (ffn)
+      "ffn.act_fn.proj": "ffn.0",  # Up proj
+      "ffn.proj_out": "ffn.2",  # Down proj
+
+      # Global Norms & Modulation
+      "norm2.layer_norm": "norm3",
+      "scale_shift_table": "modulation",
+      "proj_out": "head.head"
   }
 
+  # --- 3. Translation Logic ---
   if scan_layers:
-    # Handle scanned attn/ffn: blocks.attn1.query -> diffusion_model.blocks.{}.self_attn.q
-    for k, v in attn_ffn_map.items():
-      if nnx_path_str == f"blocks.{k}":
-        return f"diffusion_model.blocks.{{}}.{v}"
-    # Handle scanned norm: blocks.norm3 -> diffusion_model.blocks.{}.norm3
-    for k, v in norm_map.items():
-      if nnx_path_str == f"blocks.{k}":
-        return f"diffusion_model.blocks.{{}}.{v}"
+    # Scanned pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
+    if nnx_path_str.startswith("blocks."):
+      inner_suffix = nnx_path_str[len("blocks."):]
+      if inner_suffix in suffix_map:
+        return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}"
   else:
-    # Handle non-scanned attn/ffn/norm: blocks.0.attn1.query -> diffusion_model.blocks.0.self_attn.q
+    # Unscanned pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q"
     m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
     if m:
-      idx, suffix = m.group(1), m.group(2)
-      if suffix in attn_ffn_map:
-        return f"diffusion_model.blocks.{idx}.{attn_ffn_map[suffix]}"
-      if suffix in norm_map:
-        return f"diffusion_model.blocks.{idx}.{norm_map[suffix]}"
+      idx, inner_suffix = m.group(1), m.group(2)
+      if inner_suffix in suffix_map:
+        return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
 
   return None
+
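A minimal usage sketch of the rewritten translator is below, assuming the module import path follows the file location; the example paths and the block index 3 are illustrative, and the expected strings follow from the suffix_map above.

# Illustrative sketch only: expected translations under the new suffix_map.
# The import path is assumed from the file location in this repo.
from maxdiffusion.loaders.lora_conversion_utils import translate_wan_nnx_path_to_diffusers_lora

# Unscanned layers (scan_layers=False): the block index is preserved.
assert translate_wan_nnx_path_to_diffusers_lora(
    "blocks.10.attn1.key", scan_layers=False
) == "diffusion_model.blocks.10.self_attn.k"

# Scanned layers (scan_layers=True): a template with a {} placeholder is returned.
assert translate_wan_nnx_path_to_diffusers_lora(
    "blocks.attn1.key", scan_layers=True
) == "diffusion_model.blocks.{}.self_attn.k"

# Newly covered keys: QK norms, norm3 (NNX norm2.layer_norm), and modulation.
assert translate_wan_nnx_path_to_diffusers_lora(
    "blocks.3.scale_shift_table", scan_layers=False
) == "diffusion_model.blocks.3.modulation"

# Anything outside the embedding and block mappings falls through to None.
assert translate_wan_nnx_path_to_diffusers_lora("unrelated.path") is None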

src/maxdiffusion/models/lora_nnx.py

Lines changed: 15 additions & 16 deletions
@@ -61,8 +61,8 @@ def _compute_and_add_scanned_jit(kernel, downs, ups, alphas, global_scale, w_dif
   scales = (global_scale * alphas / rank)
   # Batch Matmul: (L, In, Out)
   delta = jnp.matmul(jnp.swapaxes(downs, 1, 2), jnp.swapaxes(ups, 1, 2))
-  delta = delta.reshape(kernel.shape)
-  kernel = kernel + (delta * scales).astype(kernel.dtype)
+  delta = (delta * scales).astype(kernel.dtype)
+  kernel = kernel + delta.reshape(kernel.shape)
 
   # 2. Apply Scanned Weight Diffs (L, ...)
   if w_diffs is not None:
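A shape-level sketch of the reordered scaling follows. The shapes are assumptions for illustration only (per-layer scales of shape (L, 1, 1), a scanned delta of shape (L, In, Out), and a higher-rank scanned kernel layout); under those assumptions, scaling and casting before the reshape is what keeps the per-layer broadcast valid.

# Assumed shapes, for illustration only (not taken from the repo).
import jax.numpy as jnp

L, rank, kh, kw, cin, cout = 4, 8, 3, 3, 16, 32
kernel = jnp.zeros((L, kh, kw, cin, cout), dtype=jnp.bfloat16)  # hypothetical scanned kernel
downs = jnp.ones((L, rank, kh * kw * cin))                      # stacked LoRA down weights
ups = jnp.ones((L, cout, rank))                                 # stacked LoRA up weights
scales = jnp.full((L, 1, 1), 0.5)                               # global_scale * alphas / rank, per layer

delta = jnp.matmul(jnp.swapaxes(downs, 1, 2), jnp.swapaxes(ups, 1, 2))  # (L, In, Out)

# Old order: reshaping first yields a 5-D delta, and the (L, 1, 1) scales no longer
# broadcast against (L, kh, kw, cin, cout), so the multiply would fail at that point.
# New order: scale and cast while delta is still (L, In, Out), then reshape and add.
delta = (delta * scales).astype(kernel.dtype)
kernel = kernel + delta.reshape(kernel.shape)
print(kernel.shape, kernel.dtype)  # (4, 3, 3, 16, 32) bfloat16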
@@ -227,7 +227,6 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
     if not isinstance(module, (nnx.Linear, nnx.Conv)): continue
 
     nnx_path_str = ".".join(map(str, path))
-    max_logging.log(f"NNX path: {nnx_path_str}")
     lora_key = translate_fn(nnx_path_str) if translate_fn else None
 
     if lora_key and lora_key in lora_params:
@@ -236,19 +235,19 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
       # Prepare LoRA terms
       down_w, up_w, current_scale = None, None, None
       if "down" in weights and "up" in weights:
-        down_w, up_w = weights["down"], weights["up"]
-        down_w, up_w = np.array(down_w), np.array(up_w)  # CPU convert
-
-        # Squeeze dimensions if needed (Conv 1x1 or Linear)
-        if isinstance(module, nnx.Conv) and module.kernel_size == (1, 1):
-          down_w, up_w = np.squeeze(down_w), np.squeeze(up_w)
-        elif isinstance(module, nnx.Conv) and module.kernel_size != (1, 1):
-          # Skip LoRA for non-1x1 convs if shapes don't align
-          pass
-
-        rank = down_w.shape[0] if down_w.ndim > 0 else 0
-        alpha = float(weights.get("alpha", rank))
-        current_scale = scale * alpha / rank
+        if isinstance(module, nnx.Conv) and module.kernel_size != (1, 1):
+          max_logging.log(f"Skipping LoRA merge for non-1x1 Conv: {lora_key}")
+        else:
+          down_w, up_w = weights["down"], weights["up"]
+          down_w, up_w = np.array(down_w), np.array(up_w)  # CPU convert
+
+          # Squeeze dimensions if needed (Conv 1x1 or Linear)
+          if isinstance(module, nnx.Conv) and module.kernel_size == (1, 1):
+            down_w, up_w = np.squeeze(down_w), np.squeeze(up_w)
+
+          rank = down_w.shape[0] if down_w.ndim > 0 else 0
+          alpha = float(weights.get("alpha", rank))
+          current_scale = scale * alpha / rank
 
       # Prepare Diff terms
       w_diff = weights.get("diff", None)
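For reference, a small sketch of the per-layer scale this branch computes, and of how a scaled Linear delta would typically be formed from it. The (rank, in) / (out, rank) layouts and the final matmul are assumptions here, since the actual delta is built in the JIT compute helpers.

# Illustrative numbers and layouts only; the real merge happens in the JIT helpers.
import numpy as np

scale = 1.0                      # user-requested LoRA strength passed to merge_lora
down_w = np.random.randn(8, 64)  # assumed (rank, in_features) layout
up_w = np.random.randn(128, 8)   # assumed (out_features, rank) layout

rank = down_w.shape[0]
alpha = 8.0                      # falls back to the rank when the checkpoint has no alpha
current_scale = scale * alpha / rank

# A typical merged delta for a Linear kernel in (in, out) layout.
delta = (down_w.T @ up_w.T) * current_scale
print(delta.shape, current_scale)  # (64, 128) 1.0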
