@@ -206,66 +206,87 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
   into the kernel of nnx.Linear and nnx.Conv layers.
   """
   lora_params = {}
+  # Parse weights and alphas
   for k, v in state_dict.items():
-    # Try matching diffusers rename format: "some.thing_lora.down.weight"
-    m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
-    if m:
-      module_path_str, weight_type = m.group(1), m.group(2)
-    else:
-      # Try matching diffusers format: "some.thing.lora.down.weight"
-      m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
-      if m:
-        module_path_str, weight_type = m.group(1), m.group(2)
-      else:
-        # Try matching kohya/lightning format: "some.thing.lora_down.weight"
-        m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
-        if m:
-          module_path_str, weight_type = m.group(1), m.group(2).replace("lora_", "")
-        else:
-          max_logging.log(f"Could not parse LoRA key: {k}")
-          continue
-
-    if module_path_str not in lora_params:
-      lora_params[module_path_str] = {}
-    lora_params[module_path_str][weight_type] = jnp.array(v)
+    if k.endswith(".alpha"):
+      module_path_str = k[: -len(".alpha")]
+      if module_path_str not in lora_params:
+        lora_params[module_path_str] = {}
+      lora_params[module_path_str]["alpha"] = jnp.array(v)
+      continue
+
+    # Try matching diffusers rename format: "some.thing_lora.down.weight"
+    m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
+    if m:
+      module_path_str, weight_type = m.group(1), m.group(2)
+    else:
+      # Try matching diffusers format: "some.thing.lora.down.weight"
+      m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
+      if m:
+        module_path_str, weight_type = m.group(1), m.group(2)
+      else:
+        # Try matching kohya/lightning format: "some.thing.lora_down.weight"
+        m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
+        if m:
+          module_path_str, weight_type = m.group(1), m.group(2).replace("lora_", "")
+        else:
+          max_logging.log(f"Could not parse LoRA key: {k}")
+          continue
+    if module_path_str not in lora_params:
+      lora_params[module_path_str] = {}
+    lora_params[module_path_str][weight_type] = jnp.array(v)
+  max_logging.log(f"Parsed {len(lora_params)} unique LoRA module keys: {list(lora_params.keys())}")

   assigned_count = 0
   for path, module in nnx.iter_graph(model):
+    if not isinstance(module, (nnx.Linear, nnx.Conv)):
+      max_logging.log(f"Skipping non-Linear/Conv layer: {module}")
+      continue
+
     nnx_path_str = ".".join(map(str, path))
+    max_logging.log(f"Checking NNX layer: {nnx_path_str}")

     matched_key = None
     if nnx_path_str in lora_params:
       matched_key = nnx_path_str
     else:
-      # Fallback: check if any param key matches end of nnx path
+      # Fallback: check if any param key is a suffix of nnx path
       for k in lora_params:
         if nnx_path_str.endswith(k):
           matched_key = k
+          max_logging.log(f"NNX path '{nnx_path_str}' matched LoRA key '{k}' via suffix.")
           break
+    max_logging.log(f"Layer: {nnx_path_str}, Matched LoRA key: {matched_key}")

     if matched_key and matched_key in lora_params:
-      weights = lora_params[matched_key]
-      if "down" in weights and "up" in weights:
-        if isinstance(module, nnx.Linear):
-          down_w = weights["down"]  # (rank, in_features)
-          up_w = weights["up"]  # (out_features_flat, rank)
-          # delta = A @ B = down.T @ up.T
-          delta = (down_w.T @ up_w.T).reshape(module.kernel.shape)
-          module.kernel.value += delta * scale
-          assigned_count += 1
-        elif isinstance(module, nnx.Conv):
-          if module.kernel_size == (1, 1):
-            down_w = weights["down"]  # (1, 1, in_c, rank)
-            up_w = weights["up"]  # (1, 1, rank, out_c)
-            # delta = down @ up for channel dimension
-            delta = (jnp.squeeze(down_w) @ jnp.squeeze(up_w)).reshape(module.kernel.shape)
-            module.kernel.value += delta * scale
-            assigned_count += 1
-          else:
-            raise NotImplementedError(
-                f"Merging LoRA weights for Conv layer {matched_key} "
-                f"with kernel_size {module.kernel_size} > 1 is not supported."
-            )
-      else:
-        max_logging.log(f"LoRA weights for {matched_key} incomplete.")
+      weights = lora_params[matched_key]
+      if "down" in weights and "up" in weights:
+        if isinstance(module, nnx.Linear):
+          down_w = weights["down"]  # (rank, in_features)
+          up_w = weights["up"]  # (out_features_flat, rank)
+          rank = down_w.shape[0]
+          alpha = weights.get("alpha", rank)
+          current_scale = scale * alpha / rank
+          # delta = A @ B = down.T @ up.T
+          delta = (down_w.T @ up_w.T).reshape(module.kernel.shape)
+          module.kernel.value += delta * current_scale
+          assigned_count += 1
+        elif isinstance(module, nnx.Conv):
+          if module.kernel_size == (1, 1):
+            down_w = weights["down"]  # (1, 1, in_c, rank)
+            up_w = weights["up"]  # (1, 1, rank, out_c)
+            rank = down_w.shape[-1]
+            alpha = weights.get("alpha", rank)
+            current_scale = scale * alpha / rank
+            # delta = down @ up for channel dimension
+            delta = (jnp.squeeze(down_w) @ jnp.squeeze(up_w)).reshape(module.kernel.shape)
+            module.kernel.value += delta * current_scale
+            assigned_count += 1
+          else:
+            raise NotImplementedError(
+                f"Merging LoRA weights for Conv layer {matched_key} "
+                f"with kernel_size {module.kernel_size} > 1 is not supported."
+            )
+      else:
+        max_logging.log(f"LoRA weights for {matched_key} incomplete.")
   max_logging.log(f"Merged weights into {assigned_count} layers in {type(model).__name__}.")
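
The new `current_scale = scale * alpha / rank` term follows the usual LoRA convention: a checkpoint's `alpha` rescales the low-rank update and defaults to the rank when no `.alpha` entry exists. A minimal standalone sketch of the Linear-layer arithmetic, with toy shapes and values that are purely illustrative and not part of this patch:

import jax.numpy as jnp

# Toy dimensions: in_features=8, out_features=4, rank=2 (illustrative only).
rank, in_features, out_features = 2, 8, 4
down_w = 0.1 * jnp.ones((rank, in_features))     # LoRA "down" (A): (rank, in)
up_w = 0.2 * jnp.ones((out_features, rank))      # LoRA "up" (B): (out, rank)
kernel = jnp.zeros((in_features, out_features))  # NNX Linear kernel: (in, out)

scale, alpha = 1.0, 2.0                          # alpha falls back to rank when absent
current_scale = scale * alpha / rank

# Same delta as in merge_lora: (in, rank) @ (rank, out) -> (in, out).
delta = (down_w.T @ up_w.T).reshape(kernel.shape)
merged = kernel + delta * current_scale

# Sanity check: down.T @ up.T is the transpose of the usual up @ down product,
# which matches the (in, out) layout of the NNX kernel.
assert jnp.allclose(merged - kernel, (up_w @ down_w).T * current_scale)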
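
For context, a hypothetical call site; the checkpoint filename, the `transformer` module, and the scale value below are illustrative assumptions, not part of this change. `safetensors.numpy.load_file` returns a plain dict of arrays, which is the `state_dict` format `merge_lora` expects:

from safetensors.numpy import load_file

lora_state_dict = load_file("pytorch_lora_weights.safetensors")  # hypothetical local file
merge_lora(transformer, lora_state_dict, scale=1.0)              # mutates kernels in place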