
Commit 6042d24

Fix
1 parent d52e25d commit 6042d24

2 files changed, 61 additions & 42 deletions


src/maxdiffusion/configs/base_wan_lora_27b.yml

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ replicate_vae: False
 # at the cost of time.
 precision: "DEFAULT"
 # Use jax.lax.scan for transformer layers
-scan_layers: True
+scan_layers: False

 # if False state is not jitted and instead replicate is called. This is good for debugging on single host
 # It must be True for multi-host.
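This flag flip is what the new merge path in lora_nnx.py relies on: per the updated docstring below, merge_lora now assumes scan_layers=False, so each WAN transformer block appears as an individually indexed NNX submodule with paths like blocks.10.attn1.query rather than a single scanned stack. A toy sketch (stand-in modules, not the real WAN transformer) of the kind of per-block paths this produces:

# Toy illustration only: stand-in modules, not the real WAN transformer.
# With unscanned (per-block) layers, every Linear is reachable at a path such as
# "blocks.<idx>.attn1.query", which is the path shape the new regex in lora_nnx.py matches.
from flax import nnx

class ToyAttention(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.query = nnx.Linear(8, 8, rngs=rngs)
    self.key = nnx.Linear(8, 8, rngs=rngs)

class ToyBlock(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.attn1 = ToyAttention(rngs)

class ToyModel(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.blocks = [ToyBlock(rngs) for _ in range(2)]

model = ToyModel(nnx.Rngs(0))
for path, module in nnx.iter_graph(model):
  if isinstance(module, nnx.Linear):
    print(".".join(map(str, path)))  # blocks.0.attn1.query, blocks.0.attn1.key, blocks.1.attn1.query, ...

With scan_layers: True the blocks would instead be stacked under a single scanned layer, and the per-block paths above would not line up with the per-block LoRA keys the merge expects.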

src/maxdiffusion/models/lora_nnx.py

Lines changed: 60 additions & 41 deletions
@@ -200,42 +200,73 @@ def inject_lora(
 
   return model
 
+def _translate_nnx_path_to_lora_key(nnx_path_str):
+  """
+  Translates NNX path like 'blocks.10.attn1.key' to
+  LoRA path like 'diffusion_model.blocks.10.self_attn.k'.
+  Returns None if no match.
+  """
+  translation_map = {
+      "attn1": "self_attn",
+      "attn2": "cross_attn",
+      "query": "q",
+      "key": "k",
+      "value": "v",
+      "proj_attn": "o",
+      "ffn.act_fn.proj": "ffn.0",
+      "ffn.proj_out": "ffn.2",
+  }
+  # Match paths like blocks.10.attn1.key or blocks.5.ffn.proj_out
+  m = re.match(r"^blocks\.(\d+)\.(attn[12]\.(?:query|key|value|proj_attn)|ffn\.(?:act_fn\.proj|proj_out))$", nnx_path_str)
+  if not m:
+    return None
+
+  block_idx, suffix = m.group(1), m.group(2)
+
+  parts = suffix.split('.')
+  if parts[0] == 'attn1' or parts[0] == 'attn2':
+    lora_part1 = translation_map[parts[0]]
+    lora_part2 = translation_map[parts[1]]
+    return f"diffusion_model.blocks.{block_idx}.{lora_part1}.{lora_part2}"
+  elif suffix in translation_map:
+    return f"diffusion_model.blocks.{block_idx}.{translation_map[suffix]}"
+  return None
+
+
 def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
   """
   Merges weights from a Diffusers-formatted state dict directly
   into the kernel of nnx.Linear and nnx.Conv layers.
+  Assumes scan_layers=False, so NNX paths include block indices.
   """
   lora_params = {}
   # Parse weights and alphas
   for k, v in state_dict.items():
     if k.endswith(".alpha"):
-      module_path_str = k[: -len(".alpha")]
-      if module_path_str not in lora_params:
-        lora_params[module_path_str] = {}
-      lora_params[module_path_str]["alpha"] = jnp.array(v)
+      key_base = k[:-len(".alpha")]
+      if key_base not in lora_params:
+        lora_params[key_base] = {}
+      lora_params[key_base]["alpha"] = jnp.array(v)
       continue
 
-    # Try matching diffusers rename format: "some.thing_lora.down.weight"
     m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
     if m:
-      module_path_str, weight_type = m.group(1), m.group(2)
+      key_base, weight_type = m.group(1), m.group(2)
     else:
-      # Try matching diffusers format: "some.thing.lora.down.weight"
       m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
       if m:
-        module_path_str, weight_type = m.group(1), m.group(2)
+        key_base, weight_type = m.group(1), m.group(2)
       else:
-        # Try matching kohya/lightning format: "some.thing.lora_down.weight"
         m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
         if m:
-          module_path_str, weight_type = m.group(1), m.group(2).replace("lora_", "")
+          key_base, weight_type = m.group(1), m.group(2).replace("lora_", "")
         else:
           max_logging.log(f"Could not parse LoRA key: {k}")
           continue
-    if module_path_str not in lora_params:
-      lora_params[module_path_str] = {}
-    lora_params[module_path_str][weight_type] = jnp.array(v)
-  max_logging.log(f"Parsed {len(lora_params)} unique LoRA module keys: {list(lora_params.keys())}")
+    if key_base not in lora_params:
+      lora_params[key_base] = {}
+    lora_params[key_base][weight_type] = jnp.array(v)
+  max_logging.log(f"Parsed {len(lora_params)} unique LoRA module keys.")
 
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
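For reference, the expected behavior of the new helper on a few example paths, derived directly from the translation map and regex above (meant to run inside lora_nnx.py, where the function and re are already in scope):

# Example translations, derived from translation_map and the regex above.
print(_translate_nnx_path_to_lora_key("blocks.10.attn1.key"))        # diffusion_model.blocks.10.self_attn.k
print(_translate_nnx_path_to_lora_key("blocks.3.attn2.proj_attn"))   # diffusion_model.blocks.3.cross_attn.o
print(_translate_nnx_path_to_lora_key("blocks.5.ffn.proj_out"))      # diffusion_model.blocks.5.ffn.2
print(_translate_nnx_path_to_lora_key("blocks.0.norm1.scale"))       # None (not a LoRA target path)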
@@ -245,48 +276,36 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
 
     nnx_path_str = ".".join(map(str, path))
     max_logging.log(f"Checking NNX layer: {nnx_path_str}")
+    lora_key = _translate_nnx_path_to_lora_key(nnx_path_str)
 
-    matched_key = None
-    if nnx_path_str in lora_params:
-      matched_key = nnx_path_str
-    else:
-      # Fallback: check if any param key is a suffix of nnx path
-      for k in lora_params:
-        if nnx_path_str.endswith(k):
-          matched_key = k
-          max_logging.log(f"NNX path '{nnx_path_str}' matched LoRA key '{k}' via suffix.")
-          break
-    max_logging.log(f"Layer: {nnx_path_str}, Matched LoRA key: {matched_key}")
-
-    if matched_key and matched_key in lora_params:
-      weights = lora_params[matched_key]
+    if lora_key and lora_key in lora_params:
+      max_logging.log(f"NNX layer '{nnx_path_str}' matched LoRA key '{lora_key}'")
+      weights = lora_params[lora_key]
       if "down" in weights and "up" in weights:
         if isinstance(module, nnx.Linear):
-          down_w = weights["down"]  # (rank, in_features)
-          up_w = weights["up"]  # (out_features_flat, rank)
+          down_w, up_w = weights["down"], weights["up"]
           rank = down_w.shape[0]
           alpha = weights.get("alpha", rank)
           current_scale = scale * alpha / rank
-          # delta = A@B = down.T @ up.T
           delta = (down_w.T @ up_w.T).reshape(module.kernel.shape)
           module.kernel.value += delta * current_scale
           assigned_count +=1
         elif isinstance(module, nnx.Conv):
-          if module.kernel_size == (1, 1):
-            down_w = weights["down"]  # (1,1,in_c,rank)
-            up_w = weights["up"]  # (1,1,rank,out_c)
+          if module.kernel_size == (1, 1):
+            down_w, up_w = weights["down"], weights["up"]
             rank = down_w.shape[-1]
             alpha = weights.get("alpha", rank)
             current_scale = scale * alpha / rank
-            # delta = down @ up for channel dimension
            delta = (jnp.squeeze(down_w) @ jnp.squeeze(up_w)).reshape(module.kernel.shape)
             module.kernel.value += delta * current_scale
             assigned_count += 1
-          else:
-            raise NotImplementedError(
-                f"Merging LoRA weights for Conv layer {matched_key} "
-                f"with kernel_size {module.kernel_size} > 1 is not supported."
-            )
+          else:
+            raise NotImplementedError(f"Conv merge only for 1x1 kernels, got {module.kernel_size}")
       else:
-        max_logging.log(f"LoRA weights for {matched_key} incomplete.")
+        max_logging.warning(f"LoRA weights for {lora_key} incomplete: missing down or up weights.")
+    elif lora_key:
+      max_logging.warning(f"NNX layer '{nnx_path_str}' translated to '{lora_key}' but key not in lora_params.")
+    else:
+      max_logging.debug(f"NNX layer '{nnx_path_str}' could not be translated to a LoRA key.")
+
   max_logging.log(f"Merged weights into {assigned_count} layers in {type(model).__name__}.")
