Skip to content

Commit 8758e29

Browse files
committed
Fix latents error
1 parent cdec1b4 commit 8758e29

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,8 +1079,8 @@ def adain_filter_latent(latents: jnp.ndarray, reference_latents: jnp.ndarray, fa
10791079
jax.Array: The transformed latent tensor.
10801080
"""
10811081
with default_env():
1082-
latents = jax.device_put(latents, jax.devices("tpu")[0])
1083-
reference_latents = jax.device_put(reference_latents, jax.devices("tpu")[0])
1082+
latents = jax.device_put(jax.numpy.array(latents), jax.devices("tpu")[0])
1083+
reference_latents = jax.device_put(jax.numpy.array(reference_latents), jax.devices("tpu")[0])
10841084

10851085
# Define the core AdaIN operation for a single (F, H, W) slice.
10861086
# This function will be vmapped over batch (B) and channel (C) dimensions.

0 commit comments

Comments
 (0)