
Commit 068d657

Commit message: test
1 parent b57f376 commit 068d657

1 file changed: src/maxdiffusion/models/lora_nnx.py
Lines changed: 26 additions & 10 deletions

@@ -169,6 +169,10 @@ def _merge_lora_layer(module, weights, scale):
     if bias_diff is not None and isinstance(module, nnx.LayerNorm) and hasattr(module, "bias") and module.bias is not None:
       module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype)
       updated = True
+  elif isinstance(module, nnx.Param):
+    if "diff" in weights:
+      module.value += np.array(weights["diff"]).reshape(module.shape).astype(module.dtype)
+      updated = True
   elif isinstance(module, (nnx.Linear, nnx.Conv)):
     # Prepare LoRA terms
     down_w, up_w, current_scale = None, None, None
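
The new nnx.Param branch lets _merge_lora_layer fold an additive "diff" directly into raw table parameters (for example AdaLN scale/shift tables) that are not wrapped in a Linear, Conv, or norm module. Below is a minimal, self-contained sketch of that branch in isolation; the table shape and the weights dict are illustrative assumptions, and only the "diff" handling mirrors the patch above.

import numpy as np
from flax import nnx

# A raw table parameter that reaches the plain nnx.Param branch because it is
# not owned by a Linear/Conv/LayerNorm module. The shape is made up for the demo.
table = nnx.Param(np.zeros((6, 512), dtype=np.float32))

# Assumed checkpoint layout: such tables ship a full additive "diff" instead
# of a down/up LoRA factorization.
weights = {"diff": np.full((6, 512), 0.01, dtype=np.float32)}

if "diff" in weights:
  # Same arithmetic as the patched branch: reshape, cast, add in place.
  table.value += np.array(weights["diff"]).reshape(table.shape).astype(table.dtype)

print(table.value.mean())  # ~0.01 after the merge
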
@@ -252,7 +256,7 @@ def merge_lora(model: nnx.Module, state_dict: dict, rank: int, scale: float, tra
 
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
-    if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)):
+    if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed, nnx.Param)):
       continue
 
     nnx_path_str = ".".join(map(str, path))
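
For context, nnx.iter_graph walks the module graph and also yields bare Param leaves, which is why adding nnx.Param to the isinstance filter is enough to reach standalone tables by their dotted path. A small sketch of that traversal; the Block module and the listed paths are assumptions, not output from this repository.

import numpy as np
from flax import nnx

class Block(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.proj = nnx.Linear(4, 4, rngs=rngs)
    self.scale_shift_table = nnx.Param(np.zeros((6, 4), dtype=np.float32))

model = Block(nnx.Rngs(0))
for path, node in nnx.iter_graph(model):
  if isinstance(node, (nnx.Linear, nnx.Param)):
    print(".".join(map(str, path)), "->", type(node).__name__)
# Assumed output: dotted paths such as "proj" (Linear) and
# "scale_shift_table", "proj.kernel", "proj.bias" (Param leaves).
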
@@ -285,23 +289,19 @@ def merge_lora_for_scanned(
 
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
+    if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed, nnx.Param)):
+      continue
 
     nnx_path_str = ".".join(map(str, path))
     lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
-    max_logging.log(f"nnx_path_str: {nnx_path_str}")
-    if 'scale_shift_table' in nnx_path_str:
+
+    if 'adaln_scale_shift_table' in nnx_path_str:
       max_logging.log(f"adaln_scale_shift_table: {nnx_path_str}")
       max_logging.log(f"lora_key_template: {lora_key_template}")
-      max_logging.log(f"Module: {module}")
+      max_logging.log(f"Module ndim: {module.ndim}")
       max_logging.log(f"Module Type: {type(module)}")
       continue
 
-    if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)):
-      continue
-
-    nnx_path_str = ".".join(map(str, path))
-    lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
-
     if not lora_key_template:
       continue
 
@@ -315,6 +315,8 @@ def merge_lora_for_scanned(
       is_scanned = module.kernel.ndim == 3
     elif isinstance(module, nnx.Conv):
       is_scanned = module.kernel.ndim == 5
+    elif isinstance(module, nnx.Param):
+      is_scanned = module.ndim > 1  # Heuristic for scanned Param
 
     if not is_scanned:
       lora_key = lora_key_template
@@ -366,6 +368,20 @@ def merge_lora_for_scanned(
         module.bias.value += bias_diffs_to_add.astype(module.bias.dtype)
       if updated_scale or updated_bias:
         assigned_count += 1
+    elif isinstance(module, nnx.Param):
+      num_layers = module.shape[0]
+      param_diffs_to_add = np.zeros_like(module.value)
+      updated = False
+      for i in range(num_layers):
+        lora_key = lora_key_template.format(i)
+        if lora_key in lora_params:
+          matched_keys.add(lora_key)
+          if "diff" in lora_params[lora_key]:
+            param_diffs_to_add[i] = np.array(lora_params[lora_key]["diff"]).reshape(module.shape[1:])
+            updated = True
+      if updated:
+        module.value += param_diffs_to_add.astype(module.dtype)
+        assigned_count += 1
     elif isinstance(module, (nnx.Linear, nnx.Conv)):
       is_linear = isinstance(module, nnx.Linear)
       is_conv = isinstance(module, nnx.Conv)
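
The scanned branch above gathers one "diff" per layer via the formatted key template and adds the stacked result to the (num_layers, ...) Param in a single update. The following standalone sketch reproduces that loop; the key template, shapes, and lora_params contents are hypothetical, chosen only to show the mechanics.

import numpy as np
from flax import nnx

num_layers, rows, dim = 4, 6, 512
# Stacked per-layer table, as produced by a scanned transformer block.
scanned_table = nnx.Param(np.zeros((num_layers, rows, dim), dtype=np.float32))

# Hypothetical per-layer LoRA entries; layer 2 carries no diff.
lora_key_template = "blocks.{}.scale_shift_table"
lora_params = {
    lora_key_template.format(i): {"diff": np.full((rows * dim,), 0.5, dtype=np.float32)}
    for i in (0, 1, 3)
}

param_diffs_to_add = np.zeros_like(scanned_table.value)
updated = False
for i in range(num_layers):
  lora_key = lora_key_template.format(i)
  if lora_key in lora_params and "diff" in lora_params[lora_key]:
    # Reshape each flat diff to the per-layer slice shape before stacking it in.
    param_diffs_to_add[i] = np.array(lora_params[lora_key]["diff"]).reshape(scanned_table.shape[1:])
    updated = True

if updated:
  scanned_table.value += param_diffs_to_add.astype(scanned_table.dtype)

print(scanned_table.value[2].max())  # layer 2 stays 0.0; the rest become 0.5
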
