
Commit aeca1ce

LoRA support for Modulation layer in WAN2.2
1 parent 9b5051c commit aeca1ce

3 files changed

Lines changed: 23 additions & 42 deletions


src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 15 additions & 1 deletion
@@ -610,6 +610,20 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
   return new_state_dict
 
 
+def preprocess_wan_lora_dict(state_dict):
+  """
+  Preprocesses WAN LoRA dict to convert diff_m to modulation.diff.
+  """
+  new_d = {}
+  for k, v in state_dict.items():
+    if k.endswith(".diff_m"):
+      new_k = k.removesuffix(".diff_m") + ".modulation.diff"
+      new_d[new_k] = v
+    else:
+      new_d[k] = v
+  return new_d
+
+
 def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
   """
   Translates WAN NNX path to Diffusers/LoRA keys.
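For illustration, the new preprocess_wan_lora_dict above simply rekeys modulation entries so the existing ".diff" merge path picks them up. A minimal, self-contained example; the key names, shapes, and import path are assumptions, only the ".diff_m" suffix handling comes from the diff:

import numpy as np
# Import path assumes this repo's package layout.
from maxdiffusion.loaders.lora_conversion_utils import preprocess_wan_lora_dict

# Hypothetical key names and shapes, for illustration only.
state_dict = {
    "blocks.0.diff_m": np.zeros((6, 1536), dtype=np.float32),
    "blocks.0.attn1.to_q.lora_down.weight": np.zeros((16, 1536), dtype=np.float32),
}

out = preprocess_wan_lora_dict(state_dict)
print(sorted(out))
# ['blocks.0.attn1.to_q.lora_down.weight', 'blocks.0.modulation.diff']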
@@ -667,7 +681,7 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
       "ffn.proj_out": "ffn.2",  # Down proj
       # Global Norms & Modulation
       "norm2.layer_norm": "norm3",
-      "scale_shift_table": "modulation",
+      "adaln_scale_shift_table": "modulation",
       "proj_out": "head.head",
   }
 
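The renamed NNX parameter adaln_scale_shift_table now resolves to the "modulation" key used by Diffusers-style WAN LoRA files. Below is a hypothetical sketch of how a component table like this can be applied to a dotted path; the helper and its substitution strategy are assumptions, not the actual translate_wan_nnx_path_to_diffusers_lora implementation:

# Hypothetical helper for illustration; entries copied from the table above.
NNX_TO_DIFFUSERS = {
    "ffn.proj_out": "ffn.2",
    "norm2.layer_norm": "norm3",
    "adaln_scale_shift_table": "modulation",
    "proj_out": "head.head",
}

def translate_component(nnx_path: str) -> str:
  # Longest keys first so "ffn.proj_out" wins over the bare "proj_out" entry.
  for src, dst in sorted(NNX_TO_DIFFUSERS.items(), key=lambda kv: -len(kv[0])):
    if src in nnx_path:
      return nnx_path.replace(src, dst)
  return nnx_path

print(translate_component("blocks.0.adaln_scale_shift_table"))  # blocks.0.modulation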

src/maxdiffusion/loaders/wan_lora_nnx_loader.py

Lines changed: 4 additions & 2 deletions
@@ -19,7 +19,7 @@
 from .lora_pipeline import StableDiffusionLoraLoaderMixin
 from ..models import lora_nnx
 from .. import max_logging
-from . import lora_conversion_utils
+from . import lora_conversion_utils, preprocess_wan_lora_dict
 
 
 class Wan2_1NNXLoraLoader(LoRABaseMixin):
@@ -50,10 +50,10 @@ def load_lora_weights(
     def translate_fn(nnx_path_str):
       return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
 
-    # Handle high noise model
     if hasattr(pipeline, "transformer") and transformer_weight_name:
       max_logging.log(f"Merging LoRA into transformer with rank={rank}")
       h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
+      h_state_dict = preprocess_wan_lora_dict(h_state_dict)
       merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
     else:
       max_logging.log("transformer not found or no weight name provided for LoRA.")
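A rough usage sketch of this code path. The keyword names lora_model_path, transformer_weight_name, rank, and scale all appear in the diff, but the full load_lora_weights signature, the call style, and the file name are assumptions:

# Hypothetical call; argument names beyond those visible in the diff are assumptions.
loader = Wan2_1NNXLoraLoader()
loader.load_lora_weights(
    pipeline,                                    # a WAN 2.1 pipeline with a .transformer attribute
    lora_model_path="some-org/some-wan-lora",    # hypothetical repo or local path
    transformer_weight_name="lora.safetensors",  # hypothetical file name
    rank=32,
    scale=1.0,
)
# Internally: lora_state_dict(...) -> preprocess_wan_lora_dict(...) -> merge_fn(pipeline.transformer, ...)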
@@ -94,6 +94,7 @@ def translate_fn(nnx_path_str: str):
     if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name:
       max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}")
       h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=high_noise_weight_name, **kwargs)
+      h_state_dict = preprocess_wan_lora_dict(h_state_dict)
       merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype)
     else:
       max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.")
@@ -102,6 +103,7 @@ def translate_fn(nnx_path_str: str):
     if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name:
       max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}")
       l_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=low_noise_weight_name, **kwargs)
+      l_state_dict = preprocess_wan_lora_dict(l_state_dict)
       merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn, dtype=dtype)
     else:
       max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.")

src/maxdiffusion/models/lora_nnx.py

Lines changed: 4 additions & 39 deletions
@@ -29,7 +29,7 @@
 
 
 @jax.jit
-def _compute_and_add_single_jit(kernel, bias, down, up, scale, w_diff, b_diff, m_diff):
+def _compute_and_add_single_jit(kernel, bias, down, up, scale, w_diff, b_diff):
   """
   Applies LoRA + Weight Diff + Bias Diff on device.
   """
@@ -48,17 +48,11 @@ def _compute_and_add_single_jit(kernel, bias, down, up, scale, w_diff, b_diff, m_diff):
   if bias is not None and b_diff is not None:
     bias = bias + b_diff.astype(bias.dtype)
 
-  # 4. Apply DoRA magnitude vector
-  if m_diff is not None:
-    kernel = kernel * m_diff.astype(kernel.dtype)
-
   return kernel, bias
 
 
 @jax.jit
-def _compute_and_add_scanned_jit(
-    kernel, downs, ups, alphas, global_scale, w_diffs=None, b_diffs=None, bias=None, m_diff=None
-):
+def _compute_and_add_scanned_jit(kernel, downs, ups, alphas, global_scale, w_diffs=None, b_diffs=None, bias=None):
   """
   Applies scanned LoRA + Diffs.
   """
@@ -80,14 +74,6 @@ def _compute_and_add_scanned_jit(
   if bias is not None and b_diffs is not None:
     bias = bias + b_diffs.astype(bias.dtype)
 
-  # 4. Apply DoRA magnitude vector
-  if m_diff is not None:
-    # Reshape for broadcasting with kernel
-    # kernel shape can be (L, In, Out) or (L, H, W, In, Out)
-    # m_diff shape is (L, Out)
-    new_shape = [m_diff.shape[0]] + [1] * (kernel.ndim - 2) + [m_diff.shape[1]]
-    kernel = kernel * m_diff.reshape(new_shape).astype(kernel.dtype)
-
   return kernel, bias
 
 
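With the DoRA magnitude branch removed from both kernels, the merge is purely additive. Below is a minimal, self-contained sketch of that idea for a single layer; the (in, rank) @ (rank, out) layout and contraction order are assumptions rather than the exact code in _compute_and_add_single_jit:

import jax.numpy as jnp

def merge_single_sketch(kernel, bias, down, up, scale, w_diff=None, b_diff=None):
  """Illustrative only: additive LoRA + weight diff + bias diff, no magnitude scaling."""
  if down is not None and up is not None:
    kernel = kernel + scale * (down @ up).astype(kernel.dtype)  # assumed (in, r) @ (r, out) layout
  if w_diff is not None:
    kernel = kernel + w_diff.astype(kernel.dtype)
  if bias is not None and b_diff is not None:
    bias = bias + b_diff.astype(bias.dtype)
  return kernel, bias

kernel = jnp.zeros((8, 8))
bias = jnp.zeros((8,))
down, up = jnp.full((8, 2), 0.1), jnp.full((2, 8), 0.1)
new_kernel, new_bias = merge_single_sketch(kernel, bias, down, up, scale=1.0)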
@@ -135,14 +121,6 @@ def parse_lora_dict(state_dict, dtype):
       lora_params[key_base]["diff"] = _to_jax_array(v, dtype=dtype)
       continue
 
-    # DoRA Magnitude (e.g., "layer.diff_m")
-    if k.endswith(".diff_m"):
-      key_base = k[: -len(".diff_m")]
-      if key_base not in lora_params:
-        lora_params[key_base] = {}
-      lora_params[key_base]["diff_m"] = _to_jax_array(v, dtype=dtype)
-      continue
-
     # Standard LoRA
     m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
     if not m:
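After the upstream rekeying, parse_lora_dict only needs its ".diff" branch plus the standard LoRA regex above. A quick self-contained check of what that regex does and does not match; the key names are made up for illustration:

import re

pattern = re.compile(r"^(.*?)\.(lora_down|lora_up)\.weight$")

m = pattern.match("blocks.0.attn1.to_q.lora_down.weight")  # hypothetical key name
print(m.group(1), m.group(2))                              # blocks.0.attn1.to_q lora_down

print(pattern.match("blocks.0.modulation.diff"))  # None -> handled by the ".diff" branch above
print(pattern.match("blocks.0.diff_m"))           # None -> rewritten upstream by preprocess_wan_lora_dict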
@@ -205,7 +183,6 @@ def _merge_lora_layer(module, weights, scale):
   # Prepare Diff terms
   w_diff = weights.get("diff", None)
   b_diff = weights.get("diff_b", None)
-  m_diff = weights.get("diff_m", None)
 
   if w_diff is not None:
     w_diff = np.array(w_diff)
@@ -219,8 +196,6 @@
       w_diff = w_diff.transpose((1, 0))
   if b_diff is not None:
     b_diff = np.array(b_diff)
-  if m_diff is not None:
-    m_diff = np.array(m_diff)
 
   # If LoCON, compute delta and add to w_diff
   if is_conv_kxk_locon:
@@ -247,9 +222,9 @@
   bias_val = module.bias.value if module.bias is not None else None
 
   # --- EXECUTE JIT UPDATE ---
-  if down_w is not None or w_diff is not None or b_diff is not None or m_diff is not None:
+  if down_w is not None or w_diff is not None or b_diff is not None:
     new_kernel, new_bias = _compute_and_add_single_jit(
-        module.kernel.value, bias_val, down_w, up_w, current_scale, w_diff, b_diff, m_diff
+        module.kernel.value, bias_val, down_w, up_w, current_scale, w_diff, b_diff
     )
 
     module.kernel.value = new_kernel
@@ -404,7 +379,6 @@ def merge_lora_for_scanned(
   # Initialize as None, allocate only if found to save memory
   stack_w_diff = None
   stack_b_diff = None
-  stack_m_diff = None
 
   has_lora = False
   has_diff = False
@@ -415,14 +389,6 @@ def merge_lora_for_scanned(
       matched_keys.add(lora_key)
       w = lora_params[lora_key]
 
-      # --- Fill DoRA Magnitude ---
-      if "m_diff" in w:
-        if stack_m_diff is None:
-          stack_m_diff = np.ones((num_layers, out_feat), dtype=np.float32)
-        dm = np.array(w["m_diff"])
-        stack_m_diff[i] = dm.flatten()
-        has_diff = True
-
       # --- Fill LoRA ---
       if "down" in w:
         d, u = np.array(w["down"]), np.array(w["up"])
@@ -494,7 +460,6 @@ def merge_lora_for_scanned(
         stack_w_diff,
         stack_b_diff,
         bias_val,
-        stack_m_diff,
       )
 
       module.kernel.value = new_k
