Commit 9390582

Fix

1 parent 77b46b6

2 files changed: 157 additions & 41 deletions

src/maxdiffusion/loaders/wan_lora_nnx_loader.py

Lines changed: 73 additions & 2 deletions
@@ -16,11 +16,78 @@
 
 from flax import nnx
 import jax
+import re
 from .lora_base import LoRABaseMixin
 from .lora_pipeline import StableDiffusionLoraLoaderMixin
 from ..models import lora_nnx
 from .. import max_logging
 
+def _translate_nnx_path_to_lora_key(nnx_path_str):
+  """
+  Translates an NNX path like 'blocks.10.attn1.key' to a
+  LoRA path like 'diffusion_model.blocks.10.self_attn.k'.
+  Returns None if no match.
+  """
+  translation_map = {
+      "attn1": "self_attn",
+      "attn2": "cross_attn",
+      "query": "q",
+      "key": "k",
+      "value": "v",
+      "proj_attn": "o",
+      "ffn.act_fn.proj": "ffn.0",
+      "ffn.proj_out": "ffn.2",
+  }
+  # Match paths like blocks.10.attn1.key or blocks.5.ffn.proj_out
+  m = re.match(r"^blocks\.(\d+)\.(attn[12]\.(?:query|key|value|proj_attn)|ffn\.(?:act_fn\.proj|proj_out))$", nnx_path_str)
+  if not m:
+    return None
+
+  block_idx, suffix = m.group(1), m.group(2)
+
+  parts = suffix.split('.')
+  if parts[0] == 'attn1' or parts[0] == 'attn2':
+    lora_part1 = translation_map[parts[0]]
+    lora_part2 = translation_map[parts[1]]
+    return f"diffusion_model.blocks.{block_idx}.{lora_part1}.{lora_part2}"
+  elif suffix in translation_map:
+    return f"diffusion_model.blocks.{block_idx}.{translation_map[suffix]}"
+  return None
+
+
+def _translate_scanned_nnx_path_to_lora_key_template(nnx_path_str):
+  """
+  Translates an NNX path like 'blocks.attn1.key' to a
+  LoRA path template like 'diffusion_model.blocks.{}.self_attn.k'.
+  Returns None if no match. This version assumes the block index is
+  missing from the path because the layers are scanned.
+  """
+  translation_map = {
+      "attn1": "self_attn",
+      "attn2": "cross_attn",
+      "query": "q",
+      "key": "k",
+      "value": "v",
+      "proj_attn": "o",
+      "ffn.act_fn.proj": "ffn.0",
+      "ffn.proj_out": "ffn.2",
+  }
+  # Match paths like blocks.attn1.key or blocks.ffn.proj_out (no block index)
+  m = re.match(r"^blocks\.(attn[12]\.(?:query|key|value|proj_attn)|ffn\.(?:act_fn\.proj|proj_out))$", nnx_path_str)
+  if not m:
+    return None
+
+  suffix = m.group(1)
+
+  parts = suffix.split('.')
+  if parts[0] == 'attn1' or parts[0] == 'attn2':
+    lora_part1 = translation_map[parts[0]]
+    lora_part2 = translation_map[parts[1]]
+    return f"diffusion_model.blocks.{{}}.{lora_part1}.{lora_part2}"
+  elif suffix in translation_map:
+    return f"diffusion_model.blocks.{{}}.{translation_map[suffix]}"
+  return None
+
 class WanNnxLoraLoader(LoRABaseMixin):
   """
   Handles loading LoRA weights into NNX-based WAN models.
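
For reference, a quick sanity check of the two helpers above (a minimal sketch; the calls are illustrative, not part of the commit, and the expected outputs follow from the regex and translation_map):

    _translate_nnx_path_to_lora_key("blocks.10.attn1.key")
    # -> 'diffusion_model.blocks.10.self_attn.k'
    _translate_nnx_path_to_lora_key("blocks.5.ffn.proj_out")
    # -> 'diffusion_model.blocks.5.ffn.2'
    _translate_nnx_path_to_lora_key("patch_embedding")
    # -> None (no match)
    _translate_scanned_nnx_path_to_lora_key_template("blocks.attn2.proj_attn")
    # -> 'diffusion_model.blocks.{}.cross_attn.o'
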
@@ -36,20 +103,24 @@ def load_lora_weights(
       low_noise_weight_name: str,
       rank: int,
       scale: float = 1.0,
+      scan_layers: bool = False,
       **kwargs,
   ):
     """
     Merges LoRA weights into the pipeline from a checkpoint.
     """
     lora_loader = StableDiffusionLoraLoaderMixin()
 
+    merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora
+    translate_fn = _translate_scanned_nnx_path_to_lora_key_template if scan_layers else _translate_nnx_path_to_lora_key
+
     # Handle high noise model
     if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name:
       max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}")
       h_state_dict, _ = lora_loader.lora_state_dict(
           lora_model_path, weight_name=high_noise_weight_name, **kwargs
       )
-      lora_nnx.merge_lora(pipeline.high_noise_transformer, h_state_dict, scale)
+      merge_fn(pipeline.high_noise_transformer, h_state_dict, scale, translate_fn)
     else:
       max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.")
 
@@ -59,7 +130,7 @@ def load_lora_weights(
       l_state_dict, _ = lora_loader.lora_state_dict(
           lora_model_path, weight_name=low_noise_weight_name, **kwargs
       )
-      lora_nnx.merge_lora(pipeline.low_noise_transformer, l_state_dict, scale)
+      merge_fn(pipeline.low_noise_transformer, l_state_dict, scale, translate_fn)
     else:
       max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.")
 
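A hedged usage sketch of the new flag (the pipeline object, checkpoint path, and weight file names are placeholders, not from this commit):

    loader = WanNnxLoraLoader()
    loader.load_lora_weights(
        pipeline,  # assumed: a WAN pipeline exposing high/low noise transformers
        lora_model_path="path/to/lora",  # placeholder
        high_noise_weight_name="high_noise_lora.safetensors",  # placeholder
        low_noise_weight_name="low_noise_lora.safetensors",  # placeholder
        rank=32,
        scale=1.0,
        scan_layers=True,  # selects merge_lora_for_scanned and the template translator
    )
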
src/maxdiffusion/models/lora_nnx.py

Lines changed: 84 additions & 39 deletions
@@ -200,40 +200,7 @@ def inject_lora(
 
   return model
 
-def _translate_nnx_path_to_lora_key(nnx_path_str):
-  """
-  Translates NNX path like 'blocks.10.attn1.key' to
-  LoRA path like 'diffusion_model.blocks.10.self_attn.k'.
-  Returns None if no match.
-  """
-  translation_map = {
-      "attn1": "self_attn",
-      "attn2": "cross_attn",
-      "query": "q",
-      "key": "k",
-      "value": "v",
-      "proj_attn": "o",
-      "ffn.act_fn.proj": "ffn.0",
-      "ffn.proj_out": "ffn.2",
-  }
-  # Match paths like blocks.10.attn1.key or blocks.5.ffn.proj_out
-  m = re.match(r"^blocks\.(\d+)\.(attn[12]\.(?:query|key|value|proj_attn)|ffn\.(?:act_fn\.proj|proj_out))$", nnx_path_str)
-  if not m:
-    return None
-
-  block_idx, suffix = m.group(1), m.group(2)
-
-  parts = suffix.split('.')
-  if parts[0] == 'attn1' or parts[0] == 'attn2':
-    lora_part1 = translation_map[parts[0]]
-    lora_part2 = translation_map[parts[1]]
-    return f"diffusion_model.blocks.{block_idx}.{lora_part1}.{lora_part2}"
-  elif suffix in translation_map:
-    return f"diffusion_model.blocks.{block_idx}.{translation_map[suffix]}"
-  return None
-
-
-def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
+def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):
   """
   Merges weights from a Diffusers-formatted state dict directly
   into the kernel of nnx.Linear and nnx.Conv layers.
@@ -271,15 +238,12 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
     if not isinstance(module, (nnx.Linear, nnx.Conv)):
-      max_logging.log(f"Skipping non-Linear/Conv layer: {module}")
       continue
 
     nnx_path_str = ".".join(map(str, path))
-    max_logging.log(f"Checking NNX layer: {nnx_path_str}")
-    lora_key = _translate_nnx_path_to_lora_key(nnx_path_str)
+    lora_key = translate_fn(nnx_path_str) if translate_fn else None
 
     if lora_key and lora_key in lora_params:
-      max_logging.log(f"NNX layer '{nnx_path_str}' matched LoRA key '{lora_key}'")
       weights = lora_params[lora_key]
       if "down" in weights and "up" in weights:
         if isinstance(module, nnx.Linear):
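
Note the inversion here: merge_lora no longer hardcodes the WAN path mapping and instead takes translate_fn from the caller. A minimal sketch of a direct call (transformer and state_dict are assumed to already exist; the import path is inferred from the file paths above):

    from maxdiffusion.loaders.wan_lora_nnx_loader import _translate_nnx_path_to_lora_key

    merge_lora(transformer, state_dict, scale=1.0, translate_fn=_translate_nnx_path_to_lora_key)
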
@@ -308,4 +272,85 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
     else:
       max_logging.log(f"NNX layer '{nnx_path_str}' could not be translated to a LoRA key.")
 
-  max_logging.log(f"Merged weights into {assigned_count} layers in {type(model).__name__}.")
+  max_logging.log(f"Merged weights into {assigned_count} layers in {type(model).__name__}.")
+
+
+def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):
+  """
+  Merges weights from a Diffusers-formatted state dict directly
+  into the kernel of nnx.Linear and nnx.Conv layers.
+  Assumes scan_layers=True, so weights are stacked if layers are scanned
+  (e.g. kernel.ndim=3 for Linear).
+  """
+  lora_params = {}
+  # Parse weights and alphas
+  for k, v in state_dict.items():
+    if k.endswith(".alpha"):
+      key_base = k[:-len(".alpha")]
+      if key_base not in lora_params:
+        lora_params[key_base] = {}
+      lora_params[key_base]["alpha"] = jnp.array(v)
+      continue
+
+    m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
+    if m:
+      key_base, weight_type = m.group(1), m.group(2)
+    else:
+      m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
+      if m:
+        key_base, weight_type = m.group(1), m.group(2)
+      else:
+        m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
+        if m:
+          key_base, weight_type = m.group(1), m.group(2).replace("lora_", "")
+        else:
+          max_logging.log(f"Could not parse LoRA key: {k}")
+          continue
+    if key_base not in lora_params:
+      lora_params[key_base] = {}
+    lora_params[key_base][weight_type] = jnp.array(v)
+  max_logging.log(f"Parsed {len(lora_params)} unique LoRA module keys for scanned merge.")
+
+  assigned_count = 0
+  for path, module in nnx.iter_graph(model):
+    if not isinstance(module, (nnx.Linear, nnx.Conv)):
+      continue
+
+    nnx_path_str = ".".join(map(str, path))
+
+    # Handle scanned Linear layers (kernel is stacked along a leading layer axis)
+    if isinstance(module, nnx.Linear) and module.kernel.ndim == 3:
+      lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
+
+      if lora_key_template:
+        num_layers, in_features, out_features = module.kernel.shape
+        deltas = []
+        has_lora = False
+        for i in range(num_layers):
+          lora_key = lora_key_template.format(i)
+          if lora_key in lora_params and "down" in lora_params[lora_key] and "up" in lora_params[lora_key]:
+            weights = lora_params[lora_key]
+            down_w, up_w = weights["down"], weights["up"]
+            rank = down_w.shape[0]
+            alpha = weights.get("alpha", rank)
+            current_scale = scale * alpha / rank
+            delta_i = (down_w.T @ up_w.T).reshape(in_features, out_features) * current_scale
+            deltas.append(delta_i)
+            has_lora = True
+          else:
+            deltas.append(jnp.zeros((in_features, out_features), dtype=module.kernel.dtype))
+
+        if has_lora:
+          stacked_delta = jnp.stack(deltas, axis=0)
+          module.kernel.value += stacked_delta
+          assigned_count += 1
+        else:
+          max_logging.log(f"Scanned layer {nnx_path_str} matched template but no LoRA weights found for any block.")
+      else:
+        max_logging.log(f"Scanned NNX layer '{nnx_path_str}' could not be translated to a LoRA key template.")
+
+    # Handle scanned Conv layers (ndim=5)
+    elif isinstance(module, nnx.Conv) and module.kernel.ndim == 5:
+      max_logging.warn(f"Merging LoRA into scanned Conv layers not implemented: {nnx_path_str}")
+
+  max_logging.log(f"Merged weights into {assigned_count} scanned layers in {type(model).__name__}.")
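
The per-block delta math mirrors the unscanned path: down is (rank, in_features) and up is (out_features, rank), so down.T @ up.T has shape (in_features, out_features), matching one slice of the stacked (num_layers, in_features, out_features) kernel. A minimal numeric sketch with made-up shapes (not from the commit):

    import jax.numpy as jnp

    num_layers, in_features, out_features, rank = 2, 8, 8, 4
    down = jnp.ones((rank, in_features))  # LoRA down-projection for block 0
    up = jnp.ones((out_features, rank))   # LoRA up-projection for block 0
    scale, alpha = 1.0, 4.0

    delta0 = (down.T @ up.T) * (scale * alpha / rank)  # (8, 8) update for block 0
    delta1 = jnp.zeros_like(delta0)                    # block 1 carries no LoRA weights
    stacked = jnp.stack([delta0, delta1], axis=0)      # (2, 8, 8), added to kernel.value
    assert stacked.shape == (num_layers, in_features, out_features)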
