
Commit d52e25d ("Fix")
1 parent: 33c831e
1 file changed: 67 additions, 46 deletions

src/maxdiffusion/models/lora_nnx.py
@@ -206,66 +206,87 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
   into the kernel of nnx.Linear and nnx.Conv layers.
   """
   lora_params = {}
+  # Parse weights and alphas
   for k, v in state_dict.items():
-    # Try matching diffusers rename format: "some.thing_lora.down.weight"
-    m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
-    if m:
-      module_path_str, weight_type = m.group(1), m.group(2)
-    else:
-      # Try matching diffusers format: "some.thing.lora.down.weight"
-      m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
-      if m:
-        module_path_str, weight_type = m.group(1), m.group(2)
-      else:
-        # Try matching kohya/lightning format: "some.thing.lora_down.weight"
-        m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
-        if m:
-          module_path_str, weight_type = m.group(1), m.group(2).replace("lora_", "")
-        else:
-          max_logging.log(f"Could not parse LoRA key: {k}")
-          continue
-
-    if module_path_str not in lora_params:
-      lora_params[module_path_str] = {}
-    lora_params[module_path_str][weight_type] = jnp.array(v)
+    if k.endswith(".alpha"):
+      module_path_str = k[: -len(".alpha")]
+      if module_path_str not in lora_params:
+        lora_params[module_path_str] = {}
+      lora_params[module_path_str]["alpha"] = jnp.array(v)
+      continue
+
+    # Try matching diffusers rename format: "some.thing_lora.down.weight"
+    m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
+    if m:
+      module_path_str, weight_type = m.group(1), m.group(2)
+    else:
+      # Try matching diffusers format: "some.thing.lora.down.weight"
+      m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
+      if m:
+        module_path_str, weight_type = m.group(1), m.group(2)
+      else:
+        # Try matching kohya/lightning format: "some.thing.lora_down.weight"
+        m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
+        if m:
+          module_path_str, weight_type = m.group(1), m.group(2).replace("lora_", "")
+        else:
+          max_logging.log(f"Could not parse LoRA key: {k}")
+          continue
+    if module_path_str not in lora_params:
+      lora_params[module_path_str] = {}
+    lora_params[module_path_str][weight_type] = jnp.array(v)
+  max_logging.log(f"Parsed {len(lora_params)} unique LoRA module keys: {list(lora_params.keys())}")
 
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
+    if not isinstance(module, (nnx.Linear, nnx.Conv)):
+      max_logging.log(f"Skipping non-Linear/Conv layer: {module}")
+      continue
+
     nnx_path_str = ".".join(map(str, path))
+    max_logging.log(f"Checking NNX layer: {nnx_path_str}")
 
     matched_key = None
     if nnx_path_str in lora_params:
       matched_key = nnx_path_str
     else:
-      # Fallback: check if any param key matches end of nnx path
+      # Fallback: check if any param key is a suffix of nnx path
       for k in lora_params:
         if nnx_path_str.endswith(k):
           matched_key = k
+          max_logging.log(f"NNX path '{nnx_path_str}' matched LoRA key '{k}' via suffix.")
           break
+    max_logging.log(f"Layer: {nnx_path_str}, Matched LoRA key: {matched_key}")
 
     if matched_key and matched_key in lora_params:
-      weights = lora_params[matched_key]
-      if "down" in weights and "up" in weights:
-        if isinstance(module, nnx.Linear):
-          down_w = weights["down"] # (rank, in_features)
-          up_w = weights["up"] # (out_features_flat, rank)
-          # delta = A@B = down.T @ up.T
-          delta = (down_w.T @ up_w.T).reshape(module.kernel.shape)
-          module.kernel.value += delta * scale
-          assigned_count +=1
-        elif isinstance(module, nnx.Conv):
-          if module.kernel_size == (1, 1):
-            down_w = weights["down"] # (1,1,in_c,rank)
-            up_w = weights["up"] # (1,1,rank,out_c)
-            # delta = down @ up for channel dimension
-            delta = (jnp.squeeze(down_w) @ jnp.squeeze(up_w)).reshape(module.kernel.shape)
-            module.kernel.value += delta * scale
-            assigned_count += 1
-          else:
-            raise NotImplementedError(
-                f"Merging LoRA weights for Conv layer {matched_key} "
-                f"with kernel_size {module.kernel_size} > 1 is not supported."
-            )
-      else:
-        max_logging.log(f"LoRA weights for {matched_key} incomplete.")
+      weights = lora_params[matched_key]
+      if "down" in weights and "up" in weights:
+        if isinstance(module, nnx.Linear):
+          down_w = weights["down"] # (rank, in_features)
+          up_w = weights["up"] # (out_features_flat, rank)
+          rank = down_w.shape[0]
+          alpha = weights.get("alpha", rank)
+          current_scale = scale * alpha / rank
+          # delta = A@B = down.T @ up.T
+          delta = (down_w.T @ up_w.T).reshape(module.kernel.shape)
+          module.kernel.value += delta * current_scale
+          assigned_count +=1
+        elif isinstance(module, nnx.Conv):
+          if module.kernel_size == (1, 1):
+            down_w = weights["down"] # (1,1,in_c,rank)
+            up_w = weights["up"] # (1,1,rank,out_c)
+            rank = down_w.shape[-1]
+            alpha = weights.get("alpha", rank)
+            current_scale = scale * alpha / rank
+            # delta = down @ up for channel dimension
+            delta = (jnp.squeeze(down_w) @ jnp.squeeze(up_w)).reshape(module.kernel.shape)
+            module.kernel.value += delta * current_scale
+            assigned_count += 1
+          else:
+            raise NotImplementedError(
+                f"Merging LoRA weights for Conv layer {matched_key} "
+                f"with kernel_size {module.kernel_size} > 1 is not supported."
+            )
+      else:
+        max_logging.log(f"LoRA weights for {matched_key} incomplete.")
   max_logging.log(f"Merged weights into {assigned_count} layers in {type(model).__name__}.")
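For illustration, here is a minimal, self-contained sketch of the three checkpoint key layouts the parser above accepts (diffusers, the diffusers "rename" form, and kohya/lightning), plus the per-module ".alpha" entries this commit starts handling. The key names and shapes below are hypothetical, and the bucketing is simplified relative to the real function:

import re

import numpy as np

# Hypothetical key names; real checkpoints use much longer module paths.
state_dict = {
    "unet.attn.to_q.lora.down.weight": np.zeros((4, 8)),    # diffusers format
    "unet.attn.to_q.lora.up.weight": np.zeros((16, 4)),
    "unet.attn.to_q.alpha": np.array(4.0),                   # per-module alpha
    "unet.mid.proj_lora.down.weight": np.zeros((4, 8)),      # diffusers rename format
    "unet.mid.proj_lora.up.weight": np.zeros((16, 4)),
    "te.mlp.fc1.lora_down.weight": np.zeros((4, 8)),         # kohya/lightning format
    "te.mlp.fc1.lora_up.weight": np.zeros((16, 4)),
}

patterns = [
    r"^(.*?)_lora\.(down|up)\.weight$",       # diffusers rename: "some.thing_lora.down.weight"
    r"^(.*?)\.lora\.(down|up)\.weight$",      # diffusers: "some.thing.lora.down.weight"
    r"^(.*?)\.(lora_down|lora_up)\.weight$",  # kohya/lightning: "some.thing.lora_down.weight"
]

lora_params = {}
for k, v in state_dict.items():
    if k.endswith(".alpha"):
        lora_params.setdefault(k[: -len(".alpha")], {})["alpha"] = v
        continue
    for pat in patterns:
        m = re.match(pat, k)
        if m:
            weight_type = m.group(2).replace("lora_", "")  # "lora_down" -> "down"
            lora_params.setdefault(m.group(1), {})[weight_type] = v
            break

print(sorted(lora_params))
# ['te.mlp.fc1', 'unet.attn.to_q', 'unet.mid.proj'], each holding "down"/"up" (and "alpha" if present)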

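And a small numeric sketch of the scaling change for an nnx.Linear kernel: the merged delta is still down.T @ up.T, but it is now weighted by scale * alpha / rank instead of scale alone; when a checkpoint carries no alpha, the code falls back to alpha = rank, which reproduces the previous behaviour. The shapes and values here are made up for illustration:

import jax.numpy as jnp

rank, in_features, out_features = 4, 8, 16
scale = 1.0                                      # user-requested merge strength
down_w = jnp.full((rank, in_features), 0.01)     # checkpoint "down" weight
up_w = jnp.full((out_features, rank), 0.02)      # checkpoint "up" weight
alpha = 2.0                                      # per-module alpha from the checkpoint

kernel = jnp.zeros((in_features, out_features))  # stands in for module.kernel.value

# Before this commit: the delta was applied with the raw scale.
old_delta = (down_w.T @ up_w.T) * scale

# After this commit: the scale is weighted by alpha / rank, the conventional
# LoRA scaling factor; a missing alpha defaults to rank, i.e. a factor of 1.
current_scale = scale * alpha / rank
new_delta = (down_w.T @ up_w.T) * current_scale

merged = kernel + new_delta
print(float(old_delta[0, 0]), float(new_delta[0, 0]))
# 0.0008 vs 0.0004 (up to float rounding), since alpha / rank = 0.5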