Skip to content

Commit 6ba53e3

Browse files
committed
refactor for full jit of diffusion loop
1 parent 00c1609 commit 6ba53e3

1 file changed

Lines changed: 144 additions & 56 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 144 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,62 +1346,27 @@ def __call__(
13461346
audio_embeds_sharded = jax.device_put(audio_embeds, spec)
13471347

13481348
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
1349-
for i in range(len(timesteps_jax)):
1350-
t = timesteps_jax[i]
1351-
1352-
# Isolate input sharding to scan_layers=False to avoid affecting the standard path
1353-
latents_jax_sharded = latents_jax
1354-
audio_latents_jax_sharded = audio_latents_jax
1355-
1356-
if not self.transformer.scan_layers:
1357-
activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
1358-
latents_jax_sharded = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names)
1359-
audio_latents_jax_sharded = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names)
1360-
1361-
noise_pred, noise_pred_audio = transformer_forward_pass(
1362-
graphdef,
1363-
state,
1364-
latents_jax_sharded,
1365-
audio_latents_jax_sharded,
1366-
t,
1367-
video_embeds_sharded,
1368-
audio_embeds_sharded,
1369-
new_attention_mask,
1370-
new_attention_mask,
1371-
guidance_scale > 1.0,
1372-
guidance_scale,
1373-
latent_num_frames,
1374-
latent_height,
1375-
latent_width,
1376-
audio_num_frames,
1377-
frame_rate,
1378-
)
1379-
1380-
if guidance_scale > 1.0:
1381-
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
1382-
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1383-
# Audio guidance
1384-
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
1385-
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1386-
1387-
latents_step = latents_jax[batch_size:]
1388-
audio_latents_step = audio_latents_jax[batch_size:]
1389-
else:
1390-
latents_step = latents_jax
1391-
audio_latents_step = audio_latents_jax
1392-
1393-
# Step
1394-
latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False)
1395-
audio_latents_step, _ = self.scheduler.step(
1396-
scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False
1397-
)
1398-
1399-
if guidance_scale > 1.0:
1400-
latents_jax = jnp.concatenate([latents_step] * 2, axis=0)
1401-
audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0)
1402-
else:
1403-
latents_jax = latents_step
1404-
audio_latents_jax = audio_latents_step
1349+
latents_jax, audio_latents_jax = run_diffusion_loop(
1350+
graphdef,
1351+
state,
1352+
scheduler_state,
1353+
timesteps_jax,
1354+
latents_jax,
1355+
audio_latents_jax,
1356+
video_embeds_sharded,
1357+
audio_embeds_sharded,
1358+
new_attention_mask,
1359+
guidance_scale,
1360+
latent_num_frames,
1361+
latent_height,
1362+
latent_width,
1363+
audio_num_frames,
1364+
frame_rate,
1365+
batch_size,
1366+
self.transformer.scan_layers,
1367+
self.scheduler.step,
1368+
tuple(self.config.logical_axis_rules),
1369+
)
14051370

14061371
# 8. Decode Latents
14071372
if guidance_scale > 1.0:
@@ -1574,3 +1539,126 @@ def transformer_forward_pass(
15741539
)
15751540

15761541
return noise_pred, noise_pred_audio
1542+
1543+
1544+
@partial(
    jax.jit,
    # All loop-shape / configuration arguments are static so the whole
    # denoising loop traces once; changing any of them triggers a recompile.
    # NOTE(review): `scheduler_step` (a bound method) and `logical_axis_rules`
    # must be hashable and compare equal across pipeline calls, or every call
    # will retrace — confirm `tuple(self.config.logical_axis_rules)` yields a
    # fully hashable (nested-tuple, not nested-list) value.
    static_argnames=(
        "guidance_scale",
        "latent_num_frames",
        "latent_height",
        "latent_width",
        "audio_num_frames",
        "fps",
        "batch_size",
        "scan_layers",
        "scheduler_step",
        "logical_axis_rules",
    ),
)
def run_diffusion_loop(
    graphdef,
    state,
    scheduler_state,
    timesteps_jax,
    latents_jax,
    audio_latents_jax,
    video_embeds_sharded,
    audio_embeds_sharded,
    new_attention_mask,
    guidance_scale,
    latent_num_frames,
    latent_height,
    latent_width,
    audio_num_frames,
    fps,
    batch_size,
    scan_layers,
    scheduler_step,
    logical_axis_rules,
):
    """Run the entire diffusion denoising loop inside a single jit'd scan.

    Replaces the previous per-step Python loop so the whole loop is one XLA
    computation. For each timestep in ``timesteps_jax`` it runs the
    transformer forward pass, applies classifier-free guidance (when
    ``guidance_scale > 1.0``), and advances both video and audio latents with
    ``scheduler_step``.

    Args:
        graphdef, state: nnx split of the transformer; re-merged once here so
            the module is reconstructed inside the jit trace.
        scheduler_state: scheduler state threaded through the scan carry.
            NOTE(review): the state returned by ``scheduler_step`` is
            discarded each step (``_``) and the carried state is never
            updated — this mirrors the pre-refactor loop, but confirm the
            scheduler is genuinely stateless across steps.
        timesteps_jax: 1-D float32 array of timesteps; the scan axis.
        latents_jax / audio_latents_jax: video / audio latents. When CFG is
            on, the leading axis holds both uncond and cond copies
            (2 * batch_size) — presumably uncond first; verify against caller.
        video_embeds_sharded / audio_embeds_sharded: text-conditioning
            embeddings, already device_put by the caller.
        new_attention_mask: encoder attention mask, reused for both the video
            and audio encoder inputs.
        guidance_scale: CFG scale (static). Values > 1.0 enable the
            split/combine guidance path.
        latent_num_frames, latent_height, latent_width, audio_num_frames,
        fps: static shape/rate metadata forwarded to the transformer.
        batch_size: per-branch batch size used to slice off the conditional
            half of the CFG batch.
        scan_layers: whether the transformer scans its layers; when False,
            explicit activation sharding constraints are applied instead.
        scheduler_step: the scheduler's step function (static).
        logical_axis_rules: flax logical-to-mesh axis rules, active for the
            whole loop body.

    Returns:
        Tuple ``(latents, audio_latents)`` after all timesteps.
    """
    # Rebuild the nnx module from its functional split inside the trace.
    transformer = nnx.merge(graphdef, state)

    def scan_body(carry, t):
        # carry = (video latents, audio latents, scheduler state); t is the
        # current (scalar) timestep supplied by lax.scan.
        latents, audio_latents, s_state = carry

        with nn_partitioning.axis_rules(logical_axis_rules):
            latents_sharded = latents
            audio_latents_sharded = audio_latents

            # Only the scan_layers=False path needs explicit activation
            # sharding constraints (isolated here to avoid affecting the
            # standard path — see the pre-refactor loop's comment).
            if not scan_layers:
                activation_axis_names = nn.logical_to_mesh_axes(
                    ("activation_batch", "activation_length", "activation_embed")
                )
                latents_sharded = jax.lax.with_sharding_constraint(
                    latents, activation_axis_names
                )
                audio_latents_sharded = jax.lax.with_sharding_constraint(
                    audio_latents, activation_axis_names
                )

            # Expand timestep to batch size
            t_expanded = jnp.expand_dims(t, 0).repeat(latents.shape[0])

            noise_pred, noise_pred_audio = transformer(
                hidden_states=latents_sharded,
                encoder_hidden_states=video_embeds_sharded,
                timestep=t_expanded,
                encoder_attention_mask=new_attention_mask,
                num_frames=latent_num_frames,
                height=latent_height,
                width=latent_width,
                audio_hidden_states=audio_latents_sharded,
                audio_encoder_hidden_states=audio_embeds_sharded,
                # Same mask is reused for the audio encoder input.
                audio_encoder_attention_mask=new_attention_mask,
                fps=fps,
                audio_num_frames=audio_num_frames,
                return_dict=False,
            )

            # Classifier-free guidance: guidance_scale is static, so this
            # Python `if` selects the traced branch at compile time.
            if guidance_scale > 1.0:
                # First half of the batch is uncond, second half cond.
                noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
                # Audio guidance
                (
                    noise_pred_audio_uncond,
                    noise_pred_audio_text,
                ) = jnp.split(noise_pred_audio, 2, axis=0)
                noise_pred_audio = (
                    noise_pred_audio_uncond
                    + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
                )

                # Step only one CFG branch (elements from batch_size on);
                # the batch is re-duplicated below for the next iteration.
                latents_step = latents[batch_size:]
                audio_latents_step = audio_latents[batch_size:]
            else:
                latents_step = latents
                audio_latents_step = audio_latents

            # Step scheduler
            # NOTE(review): updated scheduler state is intentionally dropped
            # (matches the original loop) — s_state is reused unmodified.
            latents_step, _ = scheduler_step(
                s_state, noise_pred, t, latents_step, return_dict=False
            )
            audio_latents_step, _ = scheduler_step(
                s_state, noise_pred_audio, t, audio_latents_step, return_dict=False
            )

            # Re-duplicate the stepped latents so the next iteration again
            # sees the (uncond, cond) stacked batch.
            if guidance_scale > 1.0:
                latents_next = jnp.concatenate([latents_step] * 2, axis=0)
                audio_latents_next = jnp.concatenate([audio_latents_step] * 2, axis=0)
            else:
                latents_next = latents_step
                audio_latents_next = audio_latents_step

            new_carry = (latents_next, audio_latents_next, s_state)
            return new_carry, None

    # Initial carry
    initial_carry = (latents_jax, audio_latents_jax, scheduler_state)

    # Run scan
    final_carry, _ = jax.lax.scan(scan_body, initial_carry, timesteps_jax)

    return final_carry[0], final_carry[1]

0 commit comments

Comments
 (0)