File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -121,19 +121,22 @@ def parse_lora_dict(state_dict, dtype):
121121 continue
122122
123123 # Standard LoRA
124- m = re .match (r"^(.*?)_lora \.(down|up )\.weight$" , k )
124+ m = re .match (r"^(.*?)\.(lora_down|lora_up )\.weight$" , k )
125125 if not m :
126- m = re .match (r"^(.*?)\.lora\.(down|up)\.weight$" , k )
127- if not m :
128- m = re .match (r"^(.*?)\.(lora_down|lora_up)\.weight$" , k )
126+ m = re .match (r"^(.*?)\.(lora_A|lora_B)\.weight$" , k )
129127
130128 if m :
131129 key_base , weight_type = m .group (1 ), m .group (2 ).replace ("lora_" , "" )
130+ if weight_type == "A" :
131+ weight_type = "down"
132+ elif weight_type == "B" :
133+ weight_type = "up"
132134 if key_base not in lora_params :
133135 lora_params [key_base ] = {}
134136 lora_params [key_base ][weight_type ] = _to_jax_array (v , dtype = dtype )
135137 else :
136138 # Fallback for exact matches of diffs if regex failed above
139+ max_logging .log (f"Key { k } did not match any LoRA pattern." )
137140 pass
138141
139142 return lora_params
You can’t perform that action at this time.
0 commit comments