@@ -225,7 +225,6 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
     if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)):
-      max_logging.log(f"Skipping non-supported module type: {module}")
       continue

     nnx_path_str = ".".join(map(str, path))
@@ -234,6 +233,10 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
     if lora_key and lora_key in lora_params:
       weights = lora_params[lora_key]

+      is_conv_kxk_locon = False
+      if isinstance(module, nnx.Conv) and module.kernel_size != (1, 1) and "down" in weights and "up" in weights:
+        is_conv_kxk_locon = True
+
       # Handle Embeddings
       if isinstance(module, nnx.Embed):
         if "diff" in weights and hasattr(module, 'embedding'):
@@ -257,20 +260,17 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
 
       # Prepare LoRA terms
       down_w, up_w, current_scale = None, None, None
-      if "down" in weights and "up" in weights:
-        if isinstance(module, nnx.Conv) and module.kernel_size != (1, 1):
-          max_logging.log(f"Skipping LoRA merge for non-1x1 Conv: {lora_key}")
-        else:
-          down_w, up_w = weights["down"], weights["up"]
-          down_w, up_w = np.array(down_w), np.array(up_w)  # CPU convert
-
-          # Squeeze dimensions if needed (Conv 1x1 or Linear)
-          if isinstance(module, nnx.Conv) and module.kernel_size == (1, 1):
-            down_w, up_w = np.squeeze(down_w), np.squeeze(up_w)
+      if "down" in weights and "up" in weights and not is_conv_kxk_locon:
+        down_w, up_w = weights["down"], weights["up"]
+        down_w, up_w = np.array(down_w), np.array(up_w)  # CPU convert
+
+        # Squeeze dimensions if needed (Conv 1x1 or Linear)
+        if isinstance(module, nnx.Conv) and module.kernel_size == (1, 1):
+          down_w, up_w = np.squeeze(down_w), np.squeeze(up_w)

-          rank = down_w.shape[0] if down_w.ndim > 0 else 0
-          alpha = float(weights.get("alpha", rank))
-          current_scale = scale * alpha / rank
+        rank = down_w.shape[0] if down_w.ndim > 0 else 0
+        alpha = float(weights.get("alpha", rank))
+        current_scale = scale * alpha / rank

       # Prepare Diff terms
       w_diff = weights.get("diff", None)
@@ -288,6 +288,25 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
           w_diff = w_diff.transpose((1, 0))
       if b_diff is not None: b_diff = np.array(b_diff)

+      # If LoCON, compute delta and add to w_diff
+      if is_conv_kxk_locon:
+        dw, uw = np.array(weights['down']), np.array(weights['up'])
+        rank, in_c, *k_dims = dw.shape
+        out_c = uw.shape[0]
+        alpha = float(weights.get("alpha", rank))
+
+        delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims)
+
+        # Transpose to flax
+        if delta_pt.ndim == 5: delta_fx = delta_pt.transpose((2, 3, 4, 1, 0))
+        else: delta_fx = delta_pt.transpose((2, 3, 1, 0))
+
+        lora_delta = delta_fx * (scale * alpha / rank)
+        if w_diff is None:
+          w_diff = lora_delta.astype(np.float32)
+        else:
+          w_diff += lora_delta.astype(w_diff.dtype)
+
       # Check for Bias existence
       bias_val = module.bias.value if module.bias is not None else None

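The hunk above folds a kxk LoCON convolution straight into the weight diff instead of sending it down the low-rank merge path. A minimal numpy sketch of that math follows, with a hypothetical helper name and toy shapes; nothing below is part of the PR itself.

import numpy as np

def locon_kxk_delta(down, up, alpha, scale=1.0):
  # Collapse LoCON kxk conv factors (PyTorch layout) into one Flax-layout kernel delta.
  rank, in_c, *k_dims = down.shape   # down: (rank, in_c, kh, kw)
  out_c = up.shape[0]                # up:   (out_c, rank, 1, 1)
  delta_pt = (up.reshape(out_c, rank) @ down.reshape(rank, -1)).reshape(out_c, in_c, *k_dims)
  # PyTorch OIHW -> Flax HWIO: kernel dims first, then in/out channels.
  if delta_pt.ndim == 5:
    delta_fx = delta_pt.transpose((2, 3, 4, 1, 0))
  else:
    delta_fx = delta_pt.transpose((2, 3, 1, 0))
  return delta_fx * (scale * alpha / rank)

# Toy shape check: rank 4, 3x3 conv, 8 -> 16 channels.
down = np.random.randn(4, 8, 3, 3).astype(np.float32)
up = np.random.randn(16, 4, 1, 1).astype(np.float32)
assert locon_kxk_delta(down, up, alpha=4.0).shape == (3, 3, 8, 16)  # HWIO, addable onto module.kernel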
@@ -322,7 +341,6 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
     if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)):
-      max_logging.log(f"Skipping non-supported module type: {module}")
       continue

     nnx_path_str = ".".join(map(str, path))
@@ -373,10 +391,11 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
     is_linear = isinstance(module, nnx.Linear) and module.kernel.ndim == 3
     is_conv = isinstance(module, nnx.Conv) and module.kernel.ndim == 5

+    is_conv_kxk = isinstance(module, nnx.Conv) and module.kernel_size != (1, 1)
+
     if is_linear:
       num_layers, in_feat, out_feat = module.kernel.shape
     elif is_conv:
-      if module.kernel_size != (1, 1): continue
       num_layers = module.kernel.shape[0]
       in_feat, out_feat = module.kernel.shape[3], module.kernel.shape[4]
     else:
@@ -412,12 +431,29 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
       # --- Fill LoRA ---
       if "down" in w:
         d, u = np.array(w["down"]), np.array(w["up"])
-        if d.ndim > 2: d = np.squeeze(d)
-        if u.ndim > 2: u = np.squeeze(u)
-        stack_down[i] = d
-        stack_up[i] = u
-        stack_alpha[i] = float(w.get("alpha", d.shape[0]))
-        has_lora = True
+        alpha = float(w.get("alpha", d.shape[0]))
+        rank = d.shape[0]
+
+        if is_conv_kxk:
+          # For LoCON kxk, compute delta and merge into stack_w_diff
+          rank, in_c, *k_dims = d.shape
+          out_c = u.shape[0]
+          delta_pt = (u.reshape(out_c, rank) @ d.reshape(rank, -1)).reshape(out_c, in_c, *k_dims)
+          if delta_pt.ndim == 5: delta_fx = delta_pt.transpose((2, 3, 4, 1, 0))
+          else: delta_fx = delta_pt.transpose((2, 3, 1, 0))
+
+          lora_delta = delta_fx * (scale * alpha / rank)
+          if stack_w_diff is None: stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32)
+          stack_w_diff[i] += lora_delta.astype(stack_w_diff.dtype)
+          has_diff = True  # Mark as having diff because we merged LoRA into w_diff
+        else:
+          # For Linear or 1x1 Conv, prepare for JIT
+          if d.ndim > 2: d = np.squeeze(d)
+          if u.ndim > 2: u = np.squeeze(u)
+          stack_down[i] = d
+          stack_up[i] = u
+          stack_alpha[i] = alpha
+          has_lora = True

       # --- Fill Weight Diff ---
       if "diff" in w:
@@ -433,7 +469,7 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
         elif is_linear and wd.ndim == 2:
           wd = wd.transpose((1, 0))

-        stack_w_diff[i] = wd
+        stack_w_diff[i] += wd
         has_diff = True

       # --- Fill Bias Diff ---
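For the scanned variant, the same per-layer delta is accumulated into a stacked weight-diff buffer rather than into the stack_down/stack_up buffers that the Linear/1x1 path prepares for the jitted merge. A rough sketch under assumed shapes (names and sizes are illustrative only):

import numpy as np

num_layers, kh, kw, in_c, out_c = 2, 3, 3, 8, 16
stacked_kernel_shape = (num_layers, kh, kw, in_c, out_c)  # scanned nnx.Conv kernel: layers stacked on axis 0

stack_w_diff = None
for i in range(num_layers):
  # Per-layer HWIO delta, e.g. from a helper like locon_kxk_delta sketched earlier.
  lora_delta = np.random.randn(kh, kw, in_c, out_c).astype(np.float32)
  if stack_w_diff is None:
    stack_w_diff = np.zeros(stacked_kernel_shape, dtype=np.float32)
  stack_w_diff[i] += lora_delta  # presumably added onto the stacked kernel afterwards

assert stack_w_diff.shape == stacked_kernel_shape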