Skip to content

Commit 87817d0

Browse files
Merge branch 'main' into wan_transformer
2 parents 56f5225 + 0af353d commit 87817d0

3 files changed

Lines changed: 669 additions & 773 deletions

File tree

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ def torch2jax(torch_tensor: torch.Tensor) -> Array:
6666
torch_tensor = torch_tensor.to("cpu")
6767

6868
numpy_value = torch_tensor.numpy()
69-
jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None)
69+
local_cpu_device_0 = jax.local_devices(backend="cpu")[0]
70+
jax_array = jnp.array(numpy_value, dtype=jnp.bfloat16 if is_bfloat16 else None, device=local_cpu_device_0)
7071
return jax_array
7172

7273

0 commit comments

Comments
 (0)