import re
import torch
import jax
+ import numpy as np
from jax import dlpack
import jax.numpy as jnp
from flax import nnx
@@ -156,8 +157,8 @@ def __call__(self, x):

def _to_jax_array(v):
  if isinstance(v, torch.Tensor):
-     return jax.device_put(dlpack.from_dlpack(v))
-   return jax.device_put(jnp.array(v))
+     return dlpack.from_dlpack(v)
+   return jnp.array(v)

def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):
  """
@@ -211,7 +212,9 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):
        alpha = weights.get("alpha", rank)
        current_scale = scale * alpha / rank
        delta = (down_w.T @ up_w.T).reshape(module.kernel.shape)
-         module.kernel.value += delta * current_scale
+         update = delta * current_scale
+         update = jax.device_put(update, module.kernel.value.sharding)
+         module.kernel.value += update
        assigned_count += 1
    elif isinstance(module, nnx.Conv):
      if module.kernel_size == (1, 1):
@@ -220,7 +223,9 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):
        alpha = weights.get("alpha", rank)
        current_scale = scale * alpha / rank
        delta = (jnp.squeeze(down_w) @ jnp.squeeze(up_w)).reshape(module.kernel.shape)
-         module.kernel.value += delta * current_scale
+         update = delta * current_scale
+         update = jax.device_put(update, module.kernel.value.sharding)
+         module.kernel.value += update
        assigned_count += 1
      else:
        raise NotImplementedError(f"Conv merge only for 1x1 kernels, got {module.kernel_size}")
@@ -240,34 +245,31 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):
  into the kernel of nnx.Linear and nnx.Conv layers.
  Assumes scan_layers=True, so weights are stacked if layers are scanned
  (e.g. kernel.ndim=3 for Linear).
+   Optimized: Accumulates updates on CPU first, then performs a single device_put.
  """
  lora_params = {}
-   # Parse weights and alphas
+   # --- Parsing Logic ---
  for k, v in state_dict.items():
    if k.endswith(".alpha"):
      key_base = k[:-len(".alpha")]
-       if key_base not in lora_params:
-         lora_params[key_base] = {}
+       if key_base not in lora_params: lora_params[key_base] = {}
      lora_params[key_base]["alpha"] = _to_jax_array(v)
      continue

    m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
+     if not m:
+       m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
+     if not m:
+       m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
+
    if m:
-       key_base, weight_type = m.group(1), m.group(2)
+       key_base, weight_type = m.group(1), m.group(2).replace("lora_", "")
+       if key_base not in lora_params: lora_params[key_base] = {}
+       lora_params[key_base][weight_type] = _to_jax_array(v)
    else:
-       m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
-       if m:
-         key_base, weight_type = m.group(1), m.group(2)
-       else:
-         m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
-         if m:
-           key_base, weight_type = m.group(1), m.group(2).replace("lora_", "")
-         else:
-           max_logging.log(f"Could not parse LoRA key: {k}")
-           continue
-     if key_base not in lora_params:
-       lora_params[key_base] = {}
-     lora_params[key_base][weight_type] = _to_jax_array(v)
+       max_logging.log(f"Could not parse LoRA key: {k}")
+       continue
+
  max_logging.log(f"Parsed {len(lora_params)} unique LoRA module keys for scanned merge.")

  assigned_count = 0
@@ -277,69 +279,92 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):

    nnx_path_str = ".".join(map(str, path))

-     # Handle scanned Linear layers
+     # --- Handle Scanned Linear (NDIM=3) ---
    if isinstance(module, nnx.Linear) and module.kernel.ndim == 3:
      lora_key_template = translate_fn(nnx_path_str) if translate_fn else None

      if lora_key_template:
        num_layers, in_features, out_features = module.kernel.shape
-         kernel_value_updated = module.kernel.value
-         lora_found_in_module = False
+
+         # 1. Create a zero-filled buffer on CPU for float32 accumulation
+         cpu_delta_buffer = np.zeros((num_layers, in_features, out_features), dtype=np.float32)
+
+         lora_found = False
        for i in range(num_layers):
          lora_key = lora_key_template.format(i)
          if lora_key in lora_params and "down" in lora_params[lora_key] and "up" in lora_params[lora_key]:
            weights = lora_params[lora_key]
-             down_w, up_w = weights["down"], weights["up"]
+             # Pull weights to CPU/Numpy for cheap calculation
+             down_w = np.array(weights["down"])
+             up_w = np.array(weights["up"])
+
            rank = down_w.shape[0]
-             alpha = weights.get("alpha", rank)
+             alpha = float(weights.get("alpha", rank))  # ensure scalar
            current_scale = scale * alpha / rank
+
+             # Compute Delta on CPU
            delta_i = (down_w.T @ up_w.T).reshape(in_features, out_features) * current_scale
-             kernel_value_updated = kernel_value_updated.at[i].add(delta_i)
-             lora_found_in_module = True
-
-         if lora_found_in_module:
-           module.kernel.value = kernel_value_updated
+
+             # Accumulate in buffer
+             cpu_delta_buffer[i] += delta_i
+             lora_found = True
+
+         if lora_found:
+           # 2. Single Transfer: Move buffer to TPU with correct sharding and dtype
+           sharded_delta = jax.device_put(
+               jnp.array(cpu_delta_buffer, dtype=module.kernel.dtype),
+               module.kernel.value.sharding
+           )
+           # 3. In-place add
+           module.kernel.value += sharded_delta
          assigned_count += 1
        else:
-           max_logging.log(f"Scanned layer {nnx_path_str} matched template but no LoRA weights found for any block.")
+           max_logging.log(f"Scanned layer {nnx_path_str} matched template but no LoRA weights found.")
      else:
        max_logging.log(f"Scanned NNX layer '{nnx_path_str}' could not be translated to a LoRA key template.")

-     # Handle scanned Conv layers (ndim=5)
+     # --- Handle Scanned Conv (NDIM=5) ---
    elif isinstance(module, nnx.Conv) and module.kernel.ndim == 5:
      if module.kernel_size != (1, 1):
        max_logging.log(f"Skipping merge for scanned Conv layer {nnx_path_str} with kernel size {module.kernel_size}, only 1x1 is supported for merging.")
        continue
-
+
      lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
      if lora_key_template:
        num_layers, _, _, in_features, out_features = module.kernel.shape
-         kernel_value_updated = module.kernel.value
-         lora_found_in_module = False
+         cpu_delta_buffer = np.zeros(module.kernel.shape, dtype=np.float32)
+         lora_found = False
+
        for i in range(num_layers):
          lora_key = lora_key_template.format(i)
          if lora_key in lora_params and "down" in lora_params[lora_key] and "up" in lora_params[lora_key]:
            weights = lora_params[lora_key]
-             down_w, up_w = weights["down"], weights["up"]
+             down_w = np.array(weights["down"])
+             up_w = np.array(weights["up"])

            if down_w.ndim == 4:
-               down_w = jnp.squeeze(down_w)
+               down_w = np.squeeze(down_w)
            if up_w.ndim == 4:
-               up_w = jnp.squeeze(up_w)
+               up_w = np.squeeze(up_w)

            rank = down_w.shape[0]
-             alpha = weights.get("alpha", rank)
+             alpha = float(weights.get("alpha", rank))
            current_scale = scale * alpha / rank
            delta_i = (down_w.T @ up_w.T).reshape(1, 1, in_features, out_features) * current_scale
-             kernel_value_updated = kernel_value_updated.at[i].add(delta_i)
-             lora_found_in_module = True
-
-         if lora_found_in_module:
-           module.kernel.value = kernel_value_updated
+             cpu_delta_buffer[i] += delta_i
+             lora_found = True
+
+         if lora_found:
+           sharded_delta = jax.device_put(
+               jnp.array(cpu_delta_buffer, dtype=module.kernel.dtype),
+               module.kernel.value.sharding
+           )
+           module.kernel.value += sharded_delta
          assigned_count += 1
        else:
-           max_logging.log(f"Scanned 1x1 Conv layer {nnx_path_str} matched template but no LoRA weights found for any block.")
+           max_logging.log(f"Scanned 1x1 Conv layer {nnx_path_str} matched template but no LoRA weights found.")
      else:
        max_logging.log(f"Scanned 1x1 Conv layer '{nnx_path_str}' could not be translated to a LoRA key template.")

+
  max_logging.log(f"Merged weights into {assigned_count} scanned layers in {type(model).__name__}.")
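
Note: the heart of this change is the accumulate-on-host, single-transfer pattern. The sketch below shows that pattern in isolation; the shapes, rank, and LoRA factors are made up for illustration, and only numpy/jax calls already used in the diff appear here.

import jax
import jax.numpy as jnp
import numpy as np

# Illustrative shapes; in the real code these come from module.kernel.shape.
num_layers, in_features, out_features, rank = 4, 16, 16, 2
kernel = jnp.zeros((num_layers, in_features, out_features), dtype=jnp.bfloat16)

# Made-up per-layer LoRA factors: down is (rank, in), up is (out, rank).
factors = {
    i: (np.full((rank, in_features), 0.01, dtype=np.float32),
        np.full((out_features, rank), 0.01, dtype=np.float32))
    for i in range(num_layers)
}

# 1. Accumulate every per-layer delta into one float32 buffer on the host.
cpu_delta = np.zeros(kernel.shape, dtype=np.float32)
for i, (down_w, up_w) in factors.items():
  cpu_delta[i] += (down_w.T @ up_w.T).reshape(in_features, out_features)

# 2. One device_put casts to the kernel dtype and reuses the kernel's sharding,
#    replacing the per-layer .at[i].add() updates of the old code.
delta = jax.device_put(jnp.array(cpu_delta, dtype=kernel.dtype), kernel.sharding)
kernel = kernel + delta

The old path built one functional update per scanned layer on the device; the new path pays for a single host-to-device transfer regardless of the number of scanned layers.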
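
For reference, merge_lora_for_scanned resolves each scanned NNX module path to a per-block checkpoint key via translate_fn and then fills the "{}" slot with the layer index (lora_key_template.format(i)). A hypothetical callback is sketched below; the path suffix and the key template are invented for illustration and will differ per model and checkpoint format.

def example_translate_fn(nnx_path_str):
  # Hypothetical mapping: scanned attention q-projection -> per-block LoRA key
  # template, with "{}" as the placeholder for the scanned layer index.
  if nnx_path_str.endswith("attn.query"):
    return "transformer_blocks_{}_attn_to_q"
  return None  # untranslatable scanned paths are logged and skipped

merge_lora_for_scanned(model, state_dict, scale=1.0, translate_fn=example_translate_fn)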