Skip to content

Commit b9ed4d8

Browse files
committed
Fix
1 parent 0b40036 commit b9ed4d8

1 file changed

Lines changed: 62 additions & 2 deletions

File tree

src/maxdiffusion/models/lora_nnx.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,34 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
224224

225225
assigned_count = 0
226226
for path, module in nnx.iter_graph(model):
227-
if not isinstance(module, (nnx.Linear, nnx.Conv)): continue
227+
if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): continue
228228

229229
nnx_path_str = ".".join(map(str, path))
230230
lora_key = translate_fn(nnx_path_str) if translate_fn else None
231231

232232
if lora_key and lora_key in lora_params:
233233
weights = lora_params[lora_key]
234+
235+
# Handle Embeddings
236+
if isinstance(module, nnx.Embed):
237+
if "diff" in weights and hasattr(module, 'embedding'):
238+
module.embedding.value += np.array(weights["diff"]).reshape(module.embedding.shape).astype(module.embedding.dtype)
239+
assigned_count += 1
240+
continue
241+
# Handle Norms
242+
elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)):
243+
scale_diff = weights.get("diff", None)
244+
bias_diff = weights.get("diff_b", None)
245+
updated = False
246+
if scale_diff is not None and hasattr(module, 'scale') and module.scale is not None:
247+
module.scale.value += np.array(scale_diff).reshape(module.scale.shape).astype(module.scale.dtype)
248+
updated = True
249+
if bias_diff is not None and isinstance(module, nnx.LayerNorm) and hasattr(module, 'bias') and module.bias is not None:
250+
module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype)
251+
updated = True
252+
if updated:
253+
assigned_count += 1
254+
continue
234255

235256
# Prepare LoRA terms
236257
down_w, up_w, current_scale = None, None, None
@@ -298,14 +319,53 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
298319

299320
assigned_count = 0
300321
for path, module in nnx.iter_graph(model):
301-
if not isinstance(module, (nnx.Linear, nnx.Conv)): continue
322+
if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): continue
302323

303324
nnx_path_str = ".".join(map(str, path))
304325
lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
305326

306327
if not lora_key_template:
307328
continue
308329

330+
# Handle Scanned Embeddings
331+
if isinstance(module, nnx.Embed) and hasattr(module, 'embedding') and module.embedding.ndim > 2:
332+
num_layers = module.embedding.shape[0]
333+
embed_diffs_to_add = np.zeros_like(module.embedding.value)
334+
updated = False
335+
for i in range(num_layers):
336+
lora_key = lora_key_template.format(i)
337+
if lora_key in lora_params and "diff" in lora_params[lora_key]:
338+
embed_diffs_to_add[i] = np.array(lora_params[lora_key]["diff"]).reshape(module.embedding.shape[1:])
339+
updated = True
340+
if updated:
341+
module.embedding.value += embed_diffs_to_add.astype(module.embedding.dtype)
342+
assigned_count += 1
343+
continue
344+
345+
# Handle Scanned Norms
346+
if isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)) and hasattr(module, 'scale') and module.scale is not None and module.scale.ndim > 1:
347+
num_layers = module.scale.shape[0]
348+
scale_diffs_to_add = np.zeros_like(module.scale.value)
349+
bias_diffs_to_add = np.zeros_like(module.bias.value) if isinstance(module, nnx.LayerNorm) and hasattr(module, 'bias') and module.bias is not None else None
350+
updated_scale, updated_bias = False, False
351+
for i in range(num_layers):
352+
lora_key = lora_key_template.format(i)
353+
if lora_key in lora_params:
354+
weights = lora_params[lora_key]
355+
if "diff" in weights:
356+
scale_diffs_to_add[i] = np.array(weights["diff"]).reshape(module.scale.shape[1:])
357+
updated_scale = True
358+
if "diff_b" in weights and bias_diffs_to_add is not None:
359+
bias_diffs_to_add[i] = np.array(weights["diff_b"]).reshape(module.bias.shape[1:])
360+
updated_bias = True
361+
if updated_scale:
362+
module.scale.value += scale_diffs_to_add.astype(module.scale.dtype)
363+
if updated_bias and bias_diffs_to_add is not None:
364+
module.bias.value += bias_diffs_to_add.astype(module.bias.dtype)
365+
if updated_scale or updated_bias:
366+
assigned_count += 1
367+
continue
368+
309369
is_linear = isinstance(module, nnx.Linear) and module.kernel.ndim == 3
310370
is_conv = isinstance(module, nnx.Conv) and module.kernel.ndim == 5
311371

0 commit comments

Comments
 (0)