Skip to content

Commit a80b371

Browse files
committed
scan_diffusion_loop param added
1 parent 79e2b86 commit a80b371

2 files changed

Lines changed: 83 additions & 21 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@ a2v_attention_kernel: 'dot_product'
66
v2a_attention_kernel: 'dot_product'
77
attention_sharding_uniform: True
88
precision: 'bf16'
9+
10+
# For scanning transformer layers
911
scan_layers: True
12+
13+
# For scanning diffusion loop
14+
scan_diffusion_loop: True
15+
1016
names_which_can_be_saved: []
1117
names_which_can_be_offloaded: []
1218
remat_policy: "NONE"

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,27 +1346,83 @@ def __call__(
13461346
audio_embeds_sharded = jax.device_put(audio_embeds, spec)
13471347

13481348
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
1349-
latents_jax, audio_latents_jax = run_diffusion_loop(
1350-
graphdef,
1351-
state,
1352-
scheduler_state,
1353-
timesteps_jax,
1354-
latents_jax,
1355-
audio_latents_jax,
1356-
video_embeds_sharded,
1357-
audio_embeds_sharded,
1358-
new_attention_mask,
1359-
guidance_scale,
1360-
latent_num_frames,
1361-
latent_height,
1362-
latent_width,
1363-
audio_num_frames,
1364-
frame_rate,
1365-
batch_size,
1366-
self.transformer.scan_layers,
1367-
self.scheduler.step,
1368-
tuple(tuple(rule) if isinstance(rule, list) else rule for rule in self.config.logical_axis_rules),
1369-
)
1349+
1350+
scan_diffusion_loop = getattr(self.config, "scan_diffusion_loop", True)
1351+
1352+
if scan_diffusion_loop:
1353+
latents_jax, audio_latents_jax = run_diffusion_loop(
1354+
graphdef,
1355+
state,
1356+
scheduler_state,
1357+
timesteps_jax,
1358+
latents_jax,
1359+
audio_latents_jax,
1360+
video_embeds_sharded,
1361+
audio_embeds_sharded,
1362+
new_attention_mask,
1363+
guidance_scale,
1364+
latent_num_frames,
1365+
latent_height,
1366+
latent_width,
1367+
audio_num_frames,
1368+
frame_rate,
1369+
batch_size,
1370+
self.transformer.scan_layers,
1371+
self.scheduler.step,
1372+
tuple(tuple(rule) if isinstance(rule, list) else rule for rule in self.config.logical_axis_rules),
1373+
)
1374+
else:
1375+
# Old Python loop path
1376+
latents_jax = latents_jax.astype(jnp.float32)
1377+
audio_latents_jax = audio_latents_jax.astype(jnp.float32)
1378+
1379+
for t in timesteps_jax:
1380+
noise_pred, noise_pred_audio = transformer_forward_pass(
1381+
graphdef,
1382+
state,
1383+
latents_jax,
1384+
audio_latents_jax,
1385+
t,
1386+
video_embeds_sharded,
1387+
audio_embeds_sharded,
1388+
new_attention_mask,
1389+
new_attention_mask,
1390+
guidance_scale > 1.0,
1391+
guidance_scale,
1392+
latent_num_frames,
1393+
latent_height,
1394+
latent_width,
1395+
audio_num_frames,
1396+
frame_rate,
1397+
)
1398+
1399+
if guidance_scale > 1.0:
1400+
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
1401+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1402+
1403+
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
1404+
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1405+
1406+
latents_step = latents_jax[batch_size:]
1407+
audio_latents_step = audio_latents_jax[batch_size:]
1408+
else:
1409+
latents_step = latents_jax
1410+
audio_latents_step = audio_latents_jax
1411+
1412+
latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False)
1413+
latents_step = latents_step.astype(jnp.float32)
1414+
1415+
audio_latents_step, _ = self.scheduler.step(
1416+
scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False
1417+
)
1418+
audio_latents_step = audio_latents_step.astype(jnp.float32)
1419+
1420+
if guidance_scale > 1.0:
1421+
latents_jax = jnp.concatenate([latents_step] * 2, axis=0)
1422+
audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0)
1423+
else:
1424+
latents_jax = latents_step
1425+
audio_latents_jax = audio_latents_step
13701426

13711427
# 8. Decode Latents
13721428
if guidance_scale > 1.0:

0 commit comments

Comments
 (0)