Skip to content

Commit eb330d2

Browse files
committed
Fix
1 parent 287552b commit eb330d2

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

src/maxdiffusion/models/lora_nnx.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Union, Tuple, Optional
1818
import re
1919
import torch
20+
import jax
2021
from jax import dlpack
2122
import jax.numpy as jnp
2223
from flax import nnx
@@ -155,8 +156,8 @@ def __call__(self, x):
155156

156157
def _to_jax_array(v):
157158
if isinstance(v, torch.Tensor):
158-
return dlpack.from_dlpack(v)
159-
return jnp.array(v)
159+
return jax.device_put(dlpack.from_dlpack(v))
160+
return jax.device_put(jnp.array(v))
160161

161162
def merge_lora(model: nnx.Module, state_dict: dict, scale: float, translate_fn=None):
162163
"""

0 commit comments

Comments
 (0)