Commit d913afd

Fix

1 parent f98b1ee commit d913afd

1 file changed: src/maxdiffusion/models/lora_nnx.py
Lines changed: 12 additions & 4 deletions
@@ -16,6 +16,9 @@
 
 from typing import Union, Tuple, Optional
 import re
+import torch
+import torch.utils.dlpack
+from jax import dlpack
 import jax.numpy as jnp
 from flax import nnx
 from .. import max_logging
@@ -200,6 +203,11 @@ def inject_lora(
 
   return model
 
+def _to_jax_array(v):
+  if isinstance(v, torch.Tensor):
+    return dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(v))
+  return jnp.array(v)
+
 def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):
   """
   Merges weights from a Diffusers-formatted state dict directly
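
For context, a minimal standalone sketch of the conversion the new _to_jax_array helper performs. The helper body is copied from the diff above; the sample tensor and printout are illustrative only:

    import torch
    import torch.utils.dlpack
    from jax import dlpack
    import jax.numpy as jnp

    def _to_jax_array(v):
      # torch tensors cross the framework boundary via DLPack, which can
      # avoid a host-side copy; everything else falls back to jnp.array.
      if isinstance(v, torch.Tensor):
        return dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(v))
      return jnp.array(v)

    t = torch.arange(4, dtype=torch.float32)
    x = _to_jax_array(t)
    print(x, x.dtype)  # [0. 1. 2. 3.] float32, now a JAX array
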
@@ -213,7 +221,7 @@ def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=N
       key_base = k[:-len(".alpha")]
       if key_base not in lora_params:
         lora_params[key_base] = {}
-      lora_params[key_base]["alpha"] = jnp.array(v)
+      lora_params[key_base]["alpha"] = _to_jax_array(v)
       continue
 
     m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
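
As a side note on the surrounding parser (unchanged here except for the conversion call), the regex splits Diffusers-style LoRA keys into a base module name and a down/up weight type. The key below is made up for illustration; real keys depend on the checkpoint:

    import re

    # Hypothetical Diffusers-convention LoRA key.
    k = "transformer.blocks.0.attn.to_q_lora.down.weight"
    m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
    print(m.group(1))  # transformer.blocks.0.attn.to_q  -> key_base
    print(m.group(2))  # down                            -> weight_type
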
@@ -232,7 +240,7 @@
       continue
     if key_base not in lora_params:
       lora_params[key_base] = {}
-    lora_params[key_base][weight_type] = jnp.array(v)
+    lora_params[key_base][weight_type] = _to_jax_array(v)
   max_logging.log(f"Parsed {len(lora_params)} unique LoRA module keys.")
 
   assigned_count = 0
@@ -289,7 +297,7 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
       key_base = k[:-len(".alpha")]
       if key_base not in lora_params:
         lora_params[key_base] = {}
-      lora_params[key_base]["alpha"] = jnp.array(v)
+      lora_params[key_base]["alpha"] = _to_jax_array(v)
       continue
 
     m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k)
@@ -308,7 +316,7 @@ def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, scale: float, tr
       continue
     if key_base not in lora_params:
       lora_params[key_base] = {}
-    lora_params[key_base][weight_type] = jnp.array(v)
+    lora_params[key_base][weight_type] = _to_jax_array(v)
   max_logging.log(f"Parsed {len(lora_params)} unique LoRA module keys for scanned merge.")
 
   assigned_count = 0
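
The practical effect: a LoRA state dict loaded as torch tensors (for example via safetensors.torch) can now be passed straight to merge_lora, with the DLPack path presumably sparing the NumPy round-trip that jnp.array would otherwise take on a torch.Tensor. A hypothetical usage sketch; the file name, the import path, and the pre-built model are assumptions, not part of this commit:

    import safetensors.torch
    from maxdiffusion.models import lora_nnx

    # safetensors.torch.load_file returns a dict[str, torch.Tensor].
    state_dict = safetensors.torch.load_file("pytorch_lora_weights.safetensors")

    # `model` is assumed to be an nnx.Module already prepared with inject_lora.
    lora_nnx.merge_lora(model, state_dict, scale=1.0)
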
