
Commit 5d7a8c2

Fix

1 parent cc72fb6 commit 5d7a8c2

1 file changed: src/maxdiffusion/models/lora_nnx.py (59 additions & 23 deletions)
@@ -225,7 +225,6 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
     if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)):
-      max_logging.log(f"Skipping non-supported module type: {module}")
       continue

     nnx_path_str = ".".join(map(str, path))
@@ -234,6 +233,10 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
     if lora_key and lora_key in lora_params:
       weights = lora_params[lora_key]

+      is_conv_kxk_locon = False
+      if isinstance(module, nnx.Conv) and module.kernel_size != (1,1) and "down" in weights and "up" in weights:
+        is_conv_kxk_locon = True
+
       # Handle Embeddings
       if isinstance(module, nnx.Embed):
         if "diff" in weights and hasattr(module, 'embedding'):
@@ -257,20 +260,17 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N

       # Prepare LoRA terms
       down_w, up_w, current_scale = None, None, None
-      if "down" in weights and "up" in weights:
-        if isinstance(module, nnx.Conv) and module.kernel_size != (1, 1):
-          max_logging.log(f"Skipping LoRA merge for non-1x1 Conv: {lora_key}")
-        else:
-          down_w, up_w = weights["down"], weights["up"]
-          down_w, up_w = np.array(down_w), np.array(up_w) # CPU convert
-
-          # Squeeze dimensions if needed (Conv 1x1 or Linear)
-          if isinstance(module, nnx.Conv) and module.kernel_size == (1, 1):
-            down_w, up_w = np.squeeze(down_w), np.squeeze(up_w)
+      if "down" in weights and "up" in weights and not is_conv_kxk_locon:
+        down_w, up_w = weights["down"], weights["up"]
+        down_w, up_w = np.array(down_w), np.array(up_w) # CPU convert
+
+        # Squeeze dimensions if needed (Conv 1x1 or Linear)
+        if isinstance(module, nnx.Conv) and module.kernel_size == (1, 1):
+          down_w, up_w = np.squeeze(down_w), np.squeeze(up_w)

-          rank = down_w.shape[0] if down_w.ndim > 0 else 0
-          alpha = float(weights.get("alpha", rank))
-          current_scale = scale * alpha / rank
+        rank = down_w.shape[0] if down_w.ndim > 0 else 0
+        alpha = float(weights.get("alpha", rank))
+        current_scale = scale * alpha / rank

       # Prepare Diff terms
       w_diff = weights.get("diff", None)
@@ -288,6 +288,25 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
           w_diff = w_diff.transpose((1,0))
         if b_diff is not None: b_diff = np.array(b_diff)

+      # If LoCON, compute delta and add to w_diff
+      if is_conv_kxk_locon:
+        dw, uw = np.array(weights['down']), np.array(weights['up'])
+        rank, in_c, *k_dims = dw.shape
+        out_c = uw.shape[0]
+        alpha = float(weights.get("alpha", rank))
+
+        delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims)
+
+        # Transpose to flax
+        if delta_pt.ndim == 5: delta_fx = delta_pt.transpose((2,3,4,1,0))
+        else: delta_fx = delta_pt.transpose((2,3,1,0))
+
+        lora_delta = delta_fx * (scale * alpha / rank)
+        if w_diff is None:
+          w_diff = lora_delta.astype(np.float32)
+        else:
+          w_diff += lora_delta.astype(w_diff.dtype)
+
       # Check for Bias existence
       bias_val = module.bias.value if module.bias is not None else None

@@ -322,7 +341,6 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
     if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)):
-      max_logging.log(f"Skipping non-supported module type: {module}")
       continue

     nnx_path_str = ".".join(map(str, path))
@@ -373,10 +391,11 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
     is_linear = isinstance(module, nnx.Linear) and module.kernel.ndim == 3
     is_conv = isinstance(module, nnx.Conv) and module.kernel.ndim == 5

+    is_conv_kxk = isinstance(module, nnx.Conv) and module.kernel_size != (1,1)
+
     if is_linear:
       num_layers, in_feat, out_feat = module.kernel.shape
     elif is_conv:
-      if module.kernel_size != (1, 1): continue
       num_layers = module.kernel.shape[0]
       in_feat, out_feat = module.kernel.shape[3], module.kernel.shape[4]
     else:
@@ -412,12 +431,29 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
       # --- Fill LoRA ---
      if "down" in w:
        d, u = np.array(w["down"]), np.array(w["up"])
-        if d.ndim > 2: d = np.squeeze(d)
-        if u.ndim > 2: u = np.squeeze(u)
-        stack_down[i] = d
-        stack_up[i] = u
-        stack_alpha[i] = float(w.get("alpha", d.shape[0]))
-        has_lora = True
+        alpha = float(w.get("alpha", d.shape[0]))
+        rank = d.shape[0]
+
+        if is_conv_kxk:
+          # For LoCON kxk, compute delta and merge into stack_w_diff
+          rank, in_c, *k_dims = d.shape
+          out_c = u.shape[0]
+          delta_pt = (u.reshape(out_c, rank) @ d.reshape(rank, -1)).reshape(out_c, in_c, *k_dims)
+          if delta_pt.ndim == 5: delta_fx = delta_pt.transpose((2,3,4,1,0))
+          else: delta_fx = delta_pt.transpose((2,3,1,0))
+
+          lora_delta = delta_fx * (scale * alpha / rank)
+          if stack_w_diff is None: stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32)
+          stack_w_diff[i] += lora_delta.astype(stack_w_diff.dtype)
+          has_diff = True # Mark as having diff because we merged LoRA into w_diff
+        else:
+          # For Linear or 1x1 Conv, prepare for JIT
+          if d.ndim > 2: d = np.squeeze(d)
+          if u.ndim > 2: u = np.squeeze(u)
+          stack_down[i] = d
+          stack_up[i] = u
+          stack_alpha[i] = alpha
+          has_lora = True

       # --- Fill Weight Diff ---
       if "diff" in w:
@@ -433,7 +469,7 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
         elif is_linear and wd.ndim == 2:
           wd = wd.transpose((1,0))

-        stack_w_diff[i] = wd
+        stack_w_diff[i] += wd
         has_diff = True

       # --- Fill Bias Diff ---
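Note: the core of this change is folding k x k LoCON conv pairs into the weight diff (instead of skipping them, as the removed branches did). The snippet below is a minimal standalone sketch of that delta computation, not code from the repository; the shapes, variable names, and the assumption of PyTorch-style (out, in, kh, kw) LoRA factors merged into a Flax HWIO kernel are illustrative.

import numpy as np

# Hypothetical example shapes for a 3x3 conv LoRA pair
rank, in_c, out_c, kh, kw = 4, 8, 16, 3, 3
scale, alpha = 1.0, 4.0

down = np.random.randn(rank, in_c, kh, kw).astype(np.float32)  # "down" factor, PyTorch layout
up = np.random.randn(out_c, rank, 1, 1).astype(np.float32)     # "up" factor maps rank -> out_c

# Low-rank product in PyTorch layout: (out_c, in_c, kh, kw)
delta_pt = (up.reshape(out_c, rank) @ down.reshape(rank, -1)).reshape(out_c, in_c, kh, kw)

# Transpose to Flax HWIO layout (kh, kw, in_c, out_c) and apply the LoRA scaling
delta_fx = delta_pt.transpose((2, 3, 1, 0)) * (scale * alpha / rank)

kernel = np.zeros((kh, kw, in_c, out_c), dtype=np.float32)  # stand-in for module.kernel
merged = kernel + delta_fx
print(merged.shape)  # (3, 3, 8, 16)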
