@@ -200,42 +200,73 @@ def inject_lora(
 
   return model
 
+def _translate_nnx_path_to_lora_key(nnx_path_str):
+  """
+  Translates an NNX path like 'blocks.10.attn1.key' to
+  a LoRA path like 'diffusion_model.blocks.10.self_attn.k'.
+  Returns None if no match.
+  """
+  translation_map = {
+      "attn1": "self_attn",
+      "attn2": "cross_attn",
+      "query": "q",
+      "key": "k",
+      "value": "v",
+      "proj_attn": "o",
+      "ffn.act_fn.proj": "ffn.0",
+      "ffn.proj_out": "ffn.2",
+  }
+  # Match paths like blocks.10.attn1.key or blocks.5.ffn.proj_out
+  m = re.match(r"^blocks\.(\d+)\.(attn[12]\.(?:query|key|value|proj_attn)|ffn\.(?:act_fn\.proj|proj_out))$", nnx_path_str)
+  if not m:
+    return None
+
+  block_idx, suffix = m.group(1), m.group(2)
+
+  parts = suffix.split('.')
+  if parts[0] == 'attn1' or parts[0] == 'attn2':
+    lora_part1 = translation_map[parts[0]]
+    lora_part2 = translation_map[parts[1]]
+    return f"diffusion_model.blocks.{block_idx}.{lora_part1}.{lora_part2}"
+  elif suffix in translation_map:
+    return f"diffusion_model.blocks.{block_idx}.{translation_map[suffix]}"
+  return None
+
+
 def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
   """
   Merges weights from a Diffusers-formatted state dict directly
   into the kernel of nnx.Linear and nnx.Conv layers.
+  Assumes scan_layers=False, so NNX paths include block indices.
   """
   lora_params = {}
   # Parse weights and alphas
   for k, v in state_dict.items():
     if k.endswith(".alpha"):
-      module_path_str = k[:-len(".alpha")]
-      if module_path_str not in lora_params:
-        lora_params[module_path_str] = {}
-      lora_params[module_path_str]["alpha"] = jnp.array(v)
+      key_base = k[:-len(".alpha")]
+      if key_base not in lora_params:
+        lora_params[key_base] = {}
+      lora_params[key_base]["alpha"] = jnp.array(v)
       continue
 
-    # Try matching diffusers rename format: "some.thing_lora.down.weight"
     m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
     if m:
-      module_path_str, weight_type = m.group(1), m.group(2)
+      key_base, weight_type = m.group(1), m.group(2)
     else:
-      # Try matching diffusers format: "some.thing.lora.down.weight"
       m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
       if m:
-        module_path_str, weight_type = m.group(1), m.group(2)
+        key_base, weight_type = m.group(1), m.group(2)
       else:
-        # Try matching kohya/lightning format: "some.thing.lora_down.weight"
        m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
        if m:
-          module_path_str, weight_type = m.group(1), m.group(2).replace("lora_", "")
+          key_base, weight_type = m.group(1), m.group(2).replace("lora_", "")
        else:
          max_logging.log(f"Could not parse LoRA key: {k}")
          continue
-    if module_path_str not in lora_params:
-      lora_params[module_path_str] = {}
-    lora_params[module_path_str][weight_type] = jnp.array(v)
-  max_logging.log(f"Parsed {len(lora_params)} unique LoRA module keys: {list(lora_params.keys())}")
+    if key_base not in lora_params:
+      lora_params[key_base] = {}
+    lora_params[key_base][weight_type] = jnp.array(v)
+  max_logging.log(f"Parsed {len(lora_params)} unique LoRA module keys.")
 
   assigned_count = 0
   for path, module in nnx.iter_graph(model):
@@ -245,48 +276,36 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
 
     nnx_path_str = ".".join(map(str, path))
     max_logging.log(f"Checking NNX layer: {nnx_path_str}")
+    lora_key = _translate_nnx_path_to_lora_key(nnx_path_str)
 
-    matched_key = None
-    if nnx_path_str in lora_params:
-      matched_key = nnx_path_str
-    else:
-      # Fallback: check if any param key is a suffix of nnx path
-      for k in lora_params:
-        if nnx_path_str.endswith(k):
-          matched_key = k
-          max_logging.log(f"NNX path '{nnx_path_str}' matched LoRA key '{k}' via suffix.")
-          break
-    max_logging.log(f"Layer: {nnx_path_str}, Matched LoRA key: {matched_key}")
-
-    if matched_key and matched_key in lora_params:
-      weights = lora_params[matched_key]
+    if lora_key and lora_key in lora_params:
+      max_logging.log(f"NNX layer '{nnx_path_str}' matched LoRA key '{lora_key}'")
+      weights = lora_params[lora_key]
       if "down" in weights and "up" in weights:
         if isinstance(module, nnx.Linear):
-          down_w = weights["down"]  # (rank, in_features)
-          up_w = weights["up"]  # (out_features_flat, rank)
+          down_w, up_w = weights["down"], weights["up"]
           rank = down_w.shape[0]
           alpha = weights.get("alpha", rank)
           current_scale = scale * alpha / rank
-          # delta = A@B = down.T @ up.T
           delta = (down_w.T @ up_w.T).reshape(module.kernel.shape)
           module.kernel.value += delta * current_scale
           assigned_count += 1
         elif isinstance(module, nnx.Conv):
-          if module.kernel_size == (1, 1):
-            down_w = weights["down"]  # (1,1,in_c,rank)
-            up_w = weights["up"]  # (1,1,rank,out_c)
+          if module.kernel_size == (1, 1):
+            down_w, up_w = weights["down"], weights["up"]
             rank = down_w.shape[-1]
            alpha = weights.get("alpha", rank)
            current_scale = scale * alpha / rank
-            # delta = down @ up for channel dimension
            delta = (jnp.squeeze(down_w) @ jnp.squeeze(up_w)).reshape(module.kernel.shape)
            module.kernel.value += delta * current_scale
            assigned_count += 1
-          else:
-            raise NotImplementedError(
-                f"Merging LoRA weights for Conv layer {matched_key} "
-                f"with kernel_size {module.kernel_size} > 1 is not supported."
-            )
+          else:
+            raise NotImplementedError(f"Conv merge only for 1x1 kernels, got {module.kernel_size}")
       else:
-        max_logging.log(f"LoRA weights for {matched_key} incomplete.")
+        max_logging.warning(f"LoRA weights for {lora_key} incomplete: missing down or up weights.")
+    elif lora_key:
+      max_logging.warning(f"NNX layer '{nnx_path_str}' translated to '{lora_key}' but key not in lora_params.")
+    else:
+      max_logging.debug(f"NNX layer '{nnx_path_str}' could not be translated to a LoRA key.")
+
   max_logging.log(f"Merged weights into {assigned_count} layers in {type(model).__name__}.")
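
For reference, the Linear branch above follows the Diffusers LoRA convention noted in the removed shape comments: down has shape (rank, in_features), up has shape (out_features, rank), and the Flax kernel, stored as (in_features, out_features), is updated by down.T @ up.T scaled by scale * alpha / rank. A minimal standalone sketch of that arithmetic, using illustrative names that are not part of this patch:

import jax.numpy as jnp

def lora_merged_kernel(kernel, down_w, up_w, scale=1.0, alpha=None):
  # kernel: (in_features, out_features), as stored by nnx.Linear
  # down_w: (rank, in_features); up_w: (out_features, rank)
  rank = down_w.shape[0]
  alpha = rank if alpha is None else alpha
  delta = (down_w.T @ up_w.T).reshape(kernel.shape)  # (in_features, out_features)
  return kernel + delta * (scale * alpha / rank)

# Example: rank-4 LoRA applied to a 16 -> 32 projection at strength 0.8.
kernel = jnp.zeros((16, 32))
merged = lora_merged_kernel(kernel, jnp.ones((4, 16)) * 0.01, jnp.ones((32, 4)) * 0.01, scale=0.8)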