@@ -200,40 +200,7 @@ def inject_lora(
 
     return model
 
-def _translate_nnx_path_to_lora_key(nnx_path_str):
-  """
-  Translates NNX path like 'blocks.10.attn1.key' to
-  LoRA path like 'diffusion_model.blocks.10.self_attn.k'.
-  Returns None if no match.
-  """
-  translation_map = {
-      "attn1": "self_attn",
-      "attn2": "cross_attn",
-      "query": "q",
-      "key": "k",
-      "value": "v",
-      "proj_attn": "o",
-      "ffn.act_fn.proj": "ffn.0",
-      "ffn.proj_out": "ffn.2",
-  }
-  # Match paths like blocks.10.attn1.key or blocks.5.ffn.proj_out
-  m = re.match(r"^blocks\.(\d+)\.(attn[12]\.(?:query|key|value|proj_attn)|ffn\.(?:act_fn\.proj|proj_out))$", nnx_path_str)
-  if not m:
-    return None
-
-  block_idx, suffix = m.group(1), m.group(2)
-
-  parts = suffix.split('.')
-  if parts[0] == 'attn1' or parts[0] == 'attn2':
-    lora_part1 = translation_map[parts[0]]
-    lora_part2 = translation_map[parts[1]]
-    return f"diffusion_model.blocks.{block_idx}.{lora_part1}.{lora_part2}"
-  elif suffix in translation_map:
-    return f"diffusion_model.blocks.{block_idx}.{translation_map[suffix]}"
-  return None
-
-
-def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
+def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):
   """
   Merges weights from a Diffusers-formatted state dict directly
   into the kernel of nnx.Linear and nnx.Conv layers.
@@ -271,15 +238,12 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
     if not isinstance(module, (nnx.Linear, nnx.Conv)):
-      max_logging.log(f"Skipping non-Linear/Conv layer: {module}")
       continue
 
     nnx_path_str = ".".join(map(str, path))
-    max_logging.log(f"Checking NNX layer: {nnx_path_str}")
-    lora_key = _translate_nnx_path_to_lora_key(nnx_path_str)
+    lora_key = translate_fn(nnx_path_str) if translate_fn else None
 
     if lora_key and lora_key in lora_params:
-      max_logging.log(f"NNX layer '{nnx_path_str}' matched LoRA key '{lora_key}'")
       weights = lora_params[lora_key]
       if "down" in weights and "up" in weights:
         if isinstance(module, nnx.Linear):
@@ -308,4 +272,85 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
     else:
       max_logging.log(f"NNX layer '{nnx_path_str}' could not be translated to a LoRA key.")
 
-  max_logging.log(f"Merged weights into {assigned_count} layers in {type(model).__name__}.")
+  max_logging.log(f"Merged weights into {assigned_count} layers in {type(model).__name__}.")
+
+
+def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):
+  """
+  Merges weights from a Diffusers-formatted state dict directly
+  into the kernel of nnx.Linear and nnx.Conv layers.
+  Assumes scan_layers=True, so layer kernels are stacked along a
+  leading axis (e.g. kernel.ndim == 3 for Linear).
+  """
+  lora_params = {}
+  # Parse weights and alphas
+  for k, v in state_dict.items():
+    if k.endswith(".alpha"):
+      key_base = k[:-len(".alpha")]
+      if key_base not in lora_params:
+        lora_params[key_base] = {}
+      lora_params[key_base]["alpha"] = jnp.array(v)
+      continue
+
+    m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
+    if m:
+      key_base, weight_type = m.group(1), m.group(2)
+    else:
+      m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
+      if m:
+        key_base, weight_type = m.group(1), m.group(2)
+      else:
+        m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
+        if m:
+          key_base, weight_type = m.group(1), m.group(2).replace("lora_", "")
+        else:
+          max_logging.log(f"Could not parse LoRA key: {k}")
+          continue
+    if key_base not in lora_params:
+      lora_params[key_base] = {}
+    lora_params[key_base][weight_type] = jnp.array(v)
+  max_logging.log(f"Parsed {len(lora_params)} unique LoRA module keys for scanned merge.")
+
+  assigned_count = 0
+  for path, module in nnx.iter_graph(model):
+    if not isinstance(module, (nnx.Linear, nnx.Conv)):
+      continue
+
+    nnx_path_str = ".".join(map(str, path))
+
+    # Handle scanned Linear layers
+    if isinstance(module, nnx.Linear) and module.kernel.ndim == 3:
+      lora_key_template = translate_fn(nnx_path_str) if translate_fn else None
+
+      if lora_key_template:
+        num_layers, in_features, out_features = module.kernel.shape
+        deltas = []
+        has_lora = False
+        for i in range(num_layers):
+          lora_key = lora_key_template.format(i)
+          if lora_key in lora_params and "down" in lora_params[lora_key] and "up" in lora_params[lora_key]:
+            weights = lora_params[lora_key]
+            down_w, up_w = weights["down"], weights["up"]
+            rank = down_w.shape[0]
+            alpha = weights.get("alpha", rank)
+            current_scale = scale * alpha / rank
+            delta_i = (down_w.T @ up_w.T).reshape(in_features, out_features) * current_scale
+            deltas.append(delta_i)
+            has_lora = True
+          else:
+            deltas.append(jnp.zeros((in_features, out_features), dtype=module.kernel.dtype))
+
+        if has_lora:
+          stacked_delta = jnp.stack(deltas, axis=0)
+          module.kernel.value += stacked_delta
+          assigned_count += 1
+        else:
+          max_logging.log(f"Scanned layer {nnx_path_str} matched template but no LoRA weights found for any block.")
+      else:
+        max_logging.log(f"Scanned NNX layer '{nnx_path_str}' could not be translated to a LoRA key template.")
+
+    # Handle scanned Conv layers (ndim=5)
+    elif isinstance(module, nnx.Conv) and module.kernel.ndim == 5:
+      max_logging.warn(f"Merging LoRA into scanned Conv layers not implemented: {nnx_path_str}")
+
+  max_logging.log(f"Merged weights into {assigned_count} scanned layers in {type(model).__name__}.")
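
For context on how a caller might wire this up, here is a hypothetical sketch that is not part of the diff: the suffix mapping mirrors the `_translate_nnx_path_to_lora_key` helper removed above, while the exact NNX path shape for scanned modules and the `model` / `state_dict` objects are assumptions that depend on how the surrounding pipeline builds the model.

```python
import re

# Illustrative suffix mapping, mirroring the helper removed in this diff.
_SUFFIX_MAP = {
    "attn1.query": "self_attn.q", "attn1.key": "self_attn.k",
    "attn1.value": "self_attn.v", "attn1.proj_attn": "self_attn.o",
    "attn2.query": "cross_attn.q", "attn2.key": "cross_attn.k",
    "attn2.value": "cross_attn.v", "attn2.proj_attn": "cross_attn.o",
    "ffn.act_fn.proj": "ffn.0", "ffn.proj_out": "ffn.2",
}

def translate_unscanned(nnx_path_str):
  # "blocks.10.attn1.key" -> "diffusion_model.blocks.10.self_attn.k"
  m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
  if m and m.group(2) in _SUFFIX_MAP:
    return f"diffusion_model.blocks.{m.group(1)}.{_SUFFIX_MAP[m.group(2)]}"
  return None

def translate_scanned(nnx_path_str):
  # With scan_layers=True the block index is absent from the NNX path, so the
  # result keeps a "{}" placeholder that merge_lora_for_scanned fills per block
  # via lora_key_template.format(i). The path prefix here is an assumption.
  m = re.match(r"^blocks\.(.+)$", nnx_path_str)
  if m and m.group(1) in _SUFFIX_MAP:
    return f"diffusion_model.blocks.{{}}.{_SUFFIX_MAP[m.group(1)]}"
  return None

# Hypothetical call sites; model and state_dict come from the loading pipeline.
# merge_lora(model, state_dict, scale=1.0, translate_fn=translate_unscanned)
# merge_lora_for_scanned(model, state_dict, scale=1.0, translate_fn=translate_scanned)
```

Per matched block the merge applies `delta = down.T @ up.T * (scale * alpha / rank)`; blocks with no LoRA weights contribute a zero delta in the scanned path so the stacked update stays aligned with the kernel's leading layer axis.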