Skip to content

Commit 34454fb

Browse files
authored
Misc fixes (#178)
* Torch2jax, explicit create on cpu * Format wan scheduler/unipc multistep Signed-off-by: Kunjan <kunjanp@google.com> --------- Signed-off-by: Kunjan <kunjanp@google.com>
1 parent 4a8155e commit 34454fb

2 files changed

Lines changed: 650 additions & 753 deletions

File tree

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ def torch2jax(torch_tensor: torch.Tensor) -> Array:
6666
torch_tensor = torch_tensor.to("cpu")
6767

6868
numpy_value = torch_tensor.numpy()
69-
jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None)
69+
local_cpu_device_0 = jax.local_devices(backend="cpu")[0]
70+
jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None, device=local_cpu_device_0)
7071
return jax_array
7172

7273

0 commit comments

Comments
 (0)