
Commit 0d221eb ("Fix")
Parent: eb330d2

1 file changed: src/maxdiffusion/models/lora_nnx.py (71 additions, 46 deletions)
@@ -18,6 +18,7 @@
 import re
 import torch
 import jax
+import numpy as np
 from jax import dlpack
 import jax.numpy as jnp
 from flax import nnx
@@ -156,8 +157,8 @@ def __call__(self, x):
 
 def _to_jax_array(v):
   if isinstance(v, torch.Tensor):
-    return jax.device_put(dlpack.from_dlpack(v))
-  return jax.device_put(jnp.array(v))
+    return dlpack.from_dlpack(v)
+  return jnp.array(v)
 
 def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):
   """
@@ -211,7 +212,9 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
         alpha = weights.get("alpha", rank)
         current_scale = scale * alpha / rank
         delta = (down_w.T @ up_w.T).reshape(module.kernel.shape)
-        module.kernel.value += delta * current_scale
+        update = delta * current_scale
+        update = jax.device_put(update, module.kernel.value.sharding)
+        module.kernel.value += update
         assigned_count +=1
     elif isinstance(module, nnx.Conv):
       if module.kernel_size == (1, 1):
@@ -220,7 +223,9 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
           alpha = weights.get("alpha", rank)
           current_scale = scale * alpha / rank
           delta = (jnp.squeeze(down_w) @ jnp.squeeze(up_w)).reshape(module.kernel.shape)
-          module.kernel.value += delta * current_scale
+          update = delta * current_scale
+          update = jax.device_put(update, module.kernel.value.sharding)
+          module.kernel.value += update
           assigned_count += 1
       else:
         raise NotImplementedError(f"Conv merge only for 1x1 kernels, got {module.kernel_size}")
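Both branches above follow the same pattern: build the LoRA delta, scale it, then explicitly place it with the destination kernel's sharding before the in-place add, so the merged update is laid out like the parameter it modifies. A hedged sketch of that pattern in isolation (the helper name and the nnx.Param-like argument are illustrative, not part of the diff):

import jax

def apply_lora_delta(kernel_param, down_w, up_w, scale, alpha=None):
  """Adds a scaled LoRA delta to an nnx.Param-like kernel, matching its sharding."""
  rank = down_w.shape[0]
  alpha = rank if alpha is None else alpha
  current_scale = scale * alpha / rank
  delta = (down_w.T @ up_w.T).reshape(kernel_param.value.shape)
  # Place the update where the kernel already lives so the += does not
  # pull the parameter onto a single device.
  update = jax.device_put(delta * current_scale, kernel_param.value.sharding)
  kernel_param.value += update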
@@ -240,34 +245,31 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
   into the kernel of nnx.Linear and nnx.Conv layers.
   Assumes scan_layers=True, so weights are stacked if layers are scanned
   (e.g. kernel.ndim=3 for Linear).
+  Optimized: Accumulates updates on CPU first, then performs a single device_put.
   """
   lora_params = {}
-  # Parse weights and alphas
+  # --- Parsing Logic ---
   for k, v in state_dict.items():
     if k.endswith(".alpha"):
       key_base = k[:-len(".alpha")]
-      if key_base not in lora_params:
-        lora_params[key_base] = {}
+      if key_base not in lora_params: lora_params[key_base] = {}
       lora_params[key_base]["alpha"] = _to_jax_array(v)
       continue
 
     m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
+    if not m:
+      m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
+    if not m:
+      m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
+
     if m:
-      key_base, weight_type = m.group(1), m.group(2)
+      key_base, weight_type = m.group(1), m.group(2).replace("lora_", "")
+      if key_base not in lora_params: lora_params[key_base] = {}
+      lora_params[key_base][weight_type] = _to_jax_array(v)
     else:
-      m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
-      if m:
-        key_base, weight_type = m.group(1), m.group(2)
-      else:
-        m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
-        if m:
-          key_base, weight_type = m.group(1), m.group(2).replace("lora_", "")
-        else:
-          max_logging.log(f"Could not parse LoRA key: {k}")
-          continue
-    if key_base not in lora_params:
-      lora_params[key_base] = {}
-    lora_params[key_base][weight_type] = _to_jax_array(v)
+      max_logging.log(f"Could not parse LoRA key: {k}")
+      continue
+
 
   max_logging.log(f"Parsed {len(lora_params)} unique LoRA module keys for scanned merge.")
 
   assigned_count = 0
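The parsing block above flattens the previous nested fallbacks into a chain of `if not m:` retries covering three common LoRA key layouts, and normalizes `lora_down`/`lora_up` to `down`/`up`. A small sketch of the keys each pattern is meant to catch (the example keys are made up for illustration):

import re

patterns = [
    r"^(.*?)_lora\.(down|up)\.weight$",        # "<base>_lora.down.weight"
    r"^(.*?)\.lora\.(down|up)\.weight$",       # "<base>.lora.up.weight"
    r"^(.*?)\.(lora_down|lora_up)\.weight$",   # "<base>.lora_down.weight"
]
example_keys = [
    "blocks_0_attn_to_q_lora.down.weight",
    "blocks.0.attn.to_q.lora.up.weight",
    "blocks_0_attn_to_q.lora_down.weight",
]
for k in example_keys:
  m = None
  for p in patterns:
    m = re.match(p, k)
    if m:
      break
  key_base, weight_type = m.group(1), m.group(2).replace("lora_", "")
  print(key_base, weight_type)   # always yields a base key plus "down" or "up"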
@@ -277,69 +279,92 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
 
     nnx_path_str = ".".join(map(str, path))
 
-    # Handle scanned Linear layers
+    # --- Handle Scanned Linear (NDIM=3) ---
     if isinstance(module, nnx.Linear) and module.kernel.ndim == 3:
       lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
 
       if lora_key_template:
         num_layers, in_features, out_features = module.kernel.shape
-        kernel_value_updated = module.kernel.value
-        lora_found_in_module = False
+
+        # 1. Create a zero-filled buffer on CPU for float32 accumulation
+        cpu_delta_buffer = np.zeros((num_layers, in_features, out_features), dtype=np.float32)
+
+        lora_found = False
         for i in range(num_layers):
           lora_key = lora_key_template.format(i)
           if lora_key in lora_params and "down" in lora_params[lora_key] and "up" in lora_params[lora_key]:
             weights = lora_params[lora_key]
-            down_w, up_w = weights["down"], weights["up"]
+            # Pull weights to CPU/Numpy for cheap calculation
+            down_w = np.array(weights["down"])
+            up_w = np.array(weights["up"])
+
             rank = down_w.shape[0]
-            alpha = weights.get("alpha", rank)
+            alpha = float(weights.get("alpha", rank))  # ensure scalar
             current_scale = scale * alpha / rank
+
+            # Compute Delta on CPU
             delta_i = (down_w.T @ up_w.T).reshape(in_features, out_features) * current_scale
-            kernel_value_updated = kernel_value_updated.at[i].add(delta_i)
-            lora_found_in_module = True
-
-        if lora_found_in_module:
-          module.kernel.value = kernel_value_updated
+
+            # Accumulate in buffer
+            cpu_delta_buffer[i] += delta_i
+            lora_found = True
+
+        if lora_found:
+          # 2. Single Transfer: Move buffer to TPU with correct sharding and dtype
+          sharded_delta = jax.device_put(
+              jnp.array(cpu_delta_buffer, dtype=module.kernel.dtype),
+              module.kernel.value.sharding
+          )
+          # 3. In-place add
+          module.kernel.value += sharded_delta
           assigned_count += 1
         else:
-          max_logging.log(f"Scanned layer {nnx_path_str} matched template but no LoRA weights found for any block.")
+          max_logging.log(f"Scanned layer {nnx_path_str} matched template but no LoRA weights found.")
       else:
         max_logging.log(f"Scanned NNX layer '{nnx_path_str}' could not be translated to a LoRA key template.")
 
-    # Handle scanned Conv layers (ndim=5)
+    # --- Handle Scanned Conv (NDIM=5) ---
    elif isinstance(module, nnx.Conv) and module.kernel.ndim == 5:
      if module.kernel_size != (1, 1):
        max_logging.log(f"Skipping merge for scanned Conv layer {nnx_path_str} with kernel size {module.kernel_size}, only 1x1 is supported for merging.")
        continue
-
+
      lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
      if lora_key_template:
        num_layers, _, _, in_features, out_features = module.kernel.shape
-        kernel_value_updated = module.kernel.value
-        lora_found_in_module = False
+        cpu_delta_buffer = np.zeros(module.kernel.shape, dtype=np.float32)
+        lora_found = False
+
        for i in range(num_layers):
          lora_key = lora_key_template.format(i)
          if lora_key in lora_params and "down" in lora_params[lora_key] and "up" in lora_params[lora_key]:
            weights = lora_params[lora_key]
-            down_w, up_w = weights["down"], weights["up"]
+            down_w = np.array(weights["down"])
+            up_w = np.array(weights["up"])
 
            if down_w.ndim == 4:
-              down_w = jnp.squeeze(down_w)
+              down_w = np.squeeze(down_w)
            if up_w.ndim == 4:
-              up_w = jnp.squeeze(up_w)
+              up_w = np.squeeze(up_w)
 
            rank = down_w.shape[0]
-            alpha = weights.get("alpha", rank)
+            alpha = float(weights.get("alpha", rank))
            current_scale = scale * alpha / rank
            delta_i = (down_w.T @ up_w.T).reshape(1, 1, in_features, out_features) * current_scale
-            kernel_value_updated = kernel_value_updated.at[i].add(delta_i)
-            lora_found_in_module = True
-
-        if lora_found_in_module:
-          module.kernel.value = kernel_value_updated
+            cpu_delta_buffer[i] += delta_i
+            lora_found = True
+
+        if lora_found:
+          sharded_delta = jax.device_put(
+              jnp.array(cpu_delta_buffer, dtype=module.kernel.dtype),
+              module.kernel.value.sharding
+          )
+          module.kernel.value += sharded_delta
          assigned_count += 1
        else:
-          max_logging.log(f"Scanned 1x1 Conv layer {nnx_path_str} matched template but no LoRA weights found for any block.")
+          max_logging.log(f"Scanned 1x1 Conv layer {nnx_path_str} matched template but no LoRA weights found.")
      else:
        max_logging.log(f"Scanned 1x1 Conv layer '{nnx_path_str}' could not be translated to a LoRA key template.")
 
+
   max_logging.log(f"Merged weights into {assigned_count} scanned layers in {type(model).__name__}.")
