Skip to content

Commit c905459

Browse files
committed
converting timesteps to jax array
1 parent 5f95986 commit c905459

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1237,7 +1237,9 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12371237
)
12381238

12391239
import time
1240-
for i, t in enumerate(timesteps):
1240+
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
1241+
for i, t_val in enumerate(timesteps):
1242+
t = timesteps_jax[i]
12411243

12421244
# Isolate input sharding to scan_layers=False to avoid affecting the standard path
12431245
latents_jax_sharded = latents_jax

0 commit comments

Comments
 (0)