Skip to content

Commit 1e594b5

Browse files
committed
Add support for AI toolkit lora format
1 parent b993abf commit 1e594b5

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

src/maxdiffusion/models/lora_nnx.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,19 +121,22 @@ def parse_lora_dict(state_dict, dtype):
121121
continue
122122

123123
# Standard LoRA
124-
m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
124+
m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
125125
if not m:
126-
m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k)
127-
if not m:
128-
m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k)
126+
m = re.match(r"^(.*?)\.(lora_A|lora_B)\.weight$", k)
129127

130128
if m:
131129
key_base, weight_type = m.group(1), m.group(2).replace("lora_", "")
130+
if weight_type == "A":
131+
weight_type = "down"
132+
elif weight_type == "B":
133+
weight_type = "up"
132134
if key_base not in lora_params:
133135
lora_params[key_base] = {}
134136
lora_params[key_base][weight_type] = _to_jax_array(v, dtype=dtype)
135137
else:
136138
# Fallback for exact matches of diffs if regex failed above
139+
max_logging.log(f"Key {k} did not match any LoRA pattern.")
137140
pass
138141

139142
return lora_params

0 commit comments

Comments
 (0)