@@ -207,17 +207,27 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
207207 """
208208 lora_params = {}
209209 for k , v in state_dict .items ():
210+ # Try matching diffusers rename format: "some.thing_lora.down.weight"
210211 m = re .match (r"^(.*?)_lora\.(down|up)\.weight$" , k )
211- if not m :
212- m = re .match (r"^(.*?)\.lora\.(down|up)\.weight$" , k )
213-
214212 if m :
215- module_path_str , weight_type = m .group (1 ), m .group (2 )
216- if module_path_str not in lora_params :
217- lora_params [module_path_str ] = {}
218- lora_params [module_path_str ][weight_type ] = jnp .array (v )
213+ module_path_str , weight_type = m .group (1 ), m .group (2 )
219214 else :
220- max_logging .log (f"Could not parse LoRA key: { k } " )
215+ # Try matching diffusers format: "some.thing.lora.down.weight"
216+ m = re .match (r"^(.*?)\.lora\.(down|up)\.weight$" , k )
217+ if m :
218+ module_path_str , weight_type = m .group (1 ), m .group (2 )
219+ else :
220+ # Try matching kohya/lightning format: "some.thing.lora_down.weight"
221+ m = re .match (r"^(.*?)\.(lora_down|lora_up)\.weight$" , k )
222+ if m :
223+ module_path_str , weight_type = m .group (1 ), m .group (2 ).replace ("lora_" , "" )
224+ else :
225+ max_logging .debug (f"Could not parse LoRA key: { k } " )
226+ continue
227+
228+ if module_path_str not in lora_params :
229+ lora_params [module_path_str ] = {}
230+ lora_params [module_path_str ][weight_type ] = jnp .array (v )
221231
222232 assigned_count = 0
223233 for path , module in nnx .iter_graph (model ):
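The three regexes above are tried in order, from most to least specific, so every key resolves to the same `(module_path_str, weight_type)` pair regardless of checkpoint flavor. A minimal standalone sketch of that dispatch; the sample keys are hypothetical, not taken from a real checkpoint:

```python
import re

# Hypothetical sample keys, one per format handled above.
samples = [
  "unet.attn1.to_q_lora.down.weight",  # diffusers rename format
  "unet.attn1.to_q.lora.down.weight",  # diffusers format
  "unet.attn1.to_q.lora_down.weight",  # kohya/lightning format
]

# Same patterns, tried in the same order as in merge_lora above.
patterns = [
  r"^(.*?)_lora\.(down|up)\.weight$",
  r"^(.*?)\.lora\.(down|up)\.weight$",
  r"^(.*?)\.(lora_down|lora_up)\.weight$",
]

for key in samples:
  for pattern in patterns:
    m = re.match(pattern, key)
    if m:
      # .replace() is a no-op for the first two formats.
      path, kind = m.group(1), m.group(2).replace("lora_", "")
      print(f"{key} -> module={path}, type={kind}")
      break
```

All three samples print `module=unet.attn1.to_q, type=down`, which is exactly the normalization `merge_lora` relies on before grouping weights by module path.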
@@ -257,5 +267,5 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float):
           f"with kernel_size {module.kernel_size} > 1 is not supported."
         )
       else:
-        max_logging.log(f"LoRA weights for {matched_key} incomplete.")
+        max_logging.warning(f"LoRA weights for {matched_key} incomplete.")
   max_logging.log(f"Merged weights into {assigned_count} layers in {type(model).__name__}.")
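For context, the merge step itself (outside these hunks) folds each complete down/up pair into the base kernel with the standard LoRA update W' = W + scale * (up @ down), which is why an unpaired down or up weight triggers the `incomplete` warning above. A minimal sketch, assuming the PyTorch-style shapes down: (rank, in_features) and up: (out_features, rank) against a Flax-style (in_features, out_features) kernel; `merged_kernel` is a hypothetical helper, not code from this file:

```python
import jax.numpy as jnp

def merged_kernel(kernel, down, up, scale):
  # Standard LoRA merge: W' = W + scale * (up @ down).
  # up @ down has shape (out_features, in_features); transpose to match the
  # Flax kernel layout (in_features, out_features). Shapes are assumptions.
  return kernel + scale * (up @ down).T
```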