Skip to content

Commit 45dbfb3

Browse files
committed
lint
1 parent 6658df7 commit 45dbfb3

1 file changed

Lines changed: 12 additions & 10 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,9 +1346,9 @@ def __call__(
13461346
audio_embeds_sharded = jax.device_put(audio_embeds, spec)
13471347

13481348
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
1349-
1349+
13501350
scan_diffusion_loop = getattr(self.config, "scan_diffusion_loop", True)
1351-
1351+
13521352
if scan_diffusion_loop:
13531353
latents_jax, audio_latents_jax = run_diffusion_loop(
13541354
graphdef,
@@ -1375,7 +1375,7 @@ def __call__(
13751375
# Old Python loop path
13761376
latents_jax = latents_jax.astype(jnp.float32)
13771377
audio_latents_jax = audio_latents_jax.astype(jnp.float32)
1378-
1378+
13791379
for t in timesteps_jax:
13801380
noise_pred, noise_pred_audio = transformer_forward_pass(
13811381
graphdef,
@@ -1395,26 +1395,28 @@ def __call__(
13951395
audio_num_frames,
13961396
frame_rate,
13971397
)
1398-
1398+
13991399
if guidance_scale > 1.0:
14001400
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
14011401
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1402-
1402+
14031403
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
14041404
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1405-
1405+
14061406
latents_step = latents_jax[batch_size:]
14071407
audio_latents_step = audio_latents_jax[batch_size:]
14081408
else:
14091409
latents_step = latents_jax
14101410
audio_latents_step = audio_latents_jax
1411-
1411+
14121412
latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False)
14131413
latents_step = latents_step.astype(jnp.float32)
1414-
1415-
audio_latents_step, _ = self.scheduler.step(scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False)
1414+
1415+
audio_latents_step, _ = self.scheduler.step(
1416+
scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False
1417+
)
14161418
audio_latents_step = audio_latents_step.astype(jnp.float32)
1417-
1419+
14181420
if guidance_scale > 1.0:
14191421
latents_jax = jnp.concatenate([latents_step] * 2, axis=0)
14201422
audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0)

0 commit comments

Comments
 (0)