Skip to content

Commit ef35f7d

Browse files
committed
JIT whole diffusion loop
1 parent 998edf4 commit ef35f7d

1 file changed

Lines changed: 11 additions & 6 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,8 +1238,8 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12381238

12391239
import time
12401240
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
1241-
for i, t_val in enumerate(timesteps):
1242-
t = timesteps_jax[i]
1241+
def step_fn(carry, t):
1242+
latents_jax, audio_latents_jax = carry
12431243

12441244
# Isolate input sharding to scan_layers=False to avoid affecting the standard path
12451245
latents_jax_sharded = latents_jax
@@ -1293,11 +1293,16 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12931293
)
12941294

12951295
if guidance_scale > 1.0:
1296-
latents_jax = jnp.concatenate([latents_step] * 2, axis=0)
1297-
audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0)
1296+
new_latents_jax = jnp.concatenate([latents_step] * 2, axis=0)
1297+
new_audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0)
12981298
else:
1299-
latents_jax = latents_step
1300-
audio_latents_jax = audio_latents_step
1299+
new_latents_jax = latents_step
1300+
new_audio_latents_jax = audio_latents_step
1301+
1302+
return (new_latents_jax, new_audio_latents_jax), None
1303+
1304+
initial_carry = (latents_jax, audio_latents_jax)
1305+
(latents_jax, audio_latents_jax), _ = jax.lax.scan(step_fn, initial_carry, timesteps_jax)
13011306

13021307

13031308

0 commit comments

Comments
 (0)