We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4a8155e commit 21ce2fbCopy full SHA for 21ce2fb
1 file changed
src/maxdiffusion/models/modeling_flax_pytorch_utils.py
@@ -66,7 +66,8 @@ def torch2jax(torch_tensor: torch.Tensor) -> Array:
66
torch_tensor = torch_tensor.to("cpu")
67
68
numpy_value = torch_tensor.numpy()
69
- jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None)
+ local_cpu_device_0 = jax.local_devices(backend="cpu")[0]
70
+ jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None, device=local_cpu_device_0)
71
return jax_array
72
73
0 commit comments