Skip to content

Commit 4ffd8c7

Browse files
committed
reformatted
1 parent b59c094 commit 4ffd8c7

1 file changed

Lines changed: 84 additions & 99 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 84 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -832,29 +832,29 @@ def encode_prompt(
832832
if do_classifier_free_guidance and negative_prompt_embeds is None:
833833
negative_prompt = negative_prompt or ""
834834
negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
835-
835+
836836
if isinstance(prompt, str):
837837
prompt = [prompt]
838-
838+
839839
combined_prompts = prompt + negative_prompt
840-
840+
841841
combined_embeds, combined_mask = self._get_gemma_prompt_embeds(
842842
prompt=combined_prompts,
843843
num_videos_per_prompt=num_videos_per_prompt,
844844
max_sequence_length=max_sequence_length,
845845
scale_factor=scale_factor,
846846
dtype=dtype,
847847
)
848-
848+
849849
split_idx = batch_size * num_videos_per_prompt
850-
850+
851851
if isinstance(combined_embeds, list):
852852
prompt_embeds = [state[:split_idx] for state in combined_embeds]
853853
negative_prompt_embeds = [state[split_idx:] for state in combined_embeds]
854854
else:
855855
prompt_embeds = combined_embeds[:split_idx]
856856
negative_prompt_embeds = combined_embeds[split_idx:]
857-
857+
858858
prompt_attention_mask = combined_mask[:split_idx]
859859
negative_prompt_attention_mask = combined_mask[split_idx:]
860860
else:
@@ -865,7 +865,7 @@ def encode_prompt(
865865
scale_factor=scale_factor,
866866
dtype=dtype,
867867
)
868-
868+
869869
if do_classifier_free_guidance and negative_prompt_embeds is None:
870870
negative_prompt = negative_prompt or ""
871871
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
@@ -1577,95 +1577,80 @@ def run_diffusion_loop(
15771577
scheduler_step,
15781578
logical_axis_rules,
15791579
):
1580-
transformer = nnx.merge(graphdef, state)
1581-
1582-
def scan_body(carry, t, model):
1583-
latents, audio_latents, s_state = carry
1584-
1585-
with nn_partitioning.axis_rules(logical_axis_rules):
1586-
latents_sharded = latents
1587-
audio_latents_sharded = audio_latents
1588-
1589-
if not scan_layers:
1590-
activation_axis_names = nn.logical_to_mesh_axes(
1591-
("activation_batch", "activation_length", "activation_embed")
1592-
)
1593-
latents_sharded = jax.lax.with_sharding_constraint(
1594-
latents, activation_axis_names
1595-
)
1596-
audio_latents_sharded = jax.lax.with_sharding_constraint(
1597-
audio_latents, activation_axis_names
1598-
)
1599-
1600-
# Expand timestep to batch size
1601-
t_expanded = jnp.expand_dims(t, 0).repeat(latents.shape[0])
1602-
1603-
noise_pred, noise_pred_audio = model(
1604-
hidden_states=latents_sharded,
1605-
encoder_hidden_states=video_embeds_sharded,
1606-
timestep=t_expanded,
1607-
encoder_attention_mask=new_attention_mask,
1608-
num_frames=latent_num_frames,
1609-
height=latent_height,
1610-
width=latent_width,
1611-
audio_hidden_states=audio_latents_sharded,
1612-
audio_encoder_hidden_states=audio_embeds_sharded,
1613-
audio_encoder_attention_mask=new_attention_mask,
1614-
fps=fps,
1615-
audio_num_frames=audio_num_frames,
1616-
return_dict=False,
1617-
)
1618-
1619-
if guidance_scale > 1.0:
1620-
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
1621-
noise_pred = noise_pred_uncond + guidance_scale * (
1622-
noise_pred_text - noise_pred_uncond
1623-
)
1624-
# Audio guidance
1625-
(
1626-
noise_pred_audio_uncond,
1627-
noise_pred_audio_text,
1628-
) = jnp.split(noise_pred_audio, 2, axis=0)
1629-
noise_pred_audio = (
1630-
noise_pred_audio_uncond
1631-
+ guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1632-
)
1633-
1634-
latents_step = latents[batch_size:]
1635-
audio_latents_step = audio_latents[batch_size:]
1636-
else:
1637-
latents_step = latents
1638-
audio_latents_step = audio_latents
1639-
1640-
# Step scheduler
1641-
latents_step, _ = scheduler_step(
1642-
s_state, noise_pred, t, latents_step, return_dict=False
1643-
)
1644-
latents_step = latents_step.astype(latents.dtype)
1645-
1646-
audio_latents_step, _ = scheduler_step(
1647-
s_state, noise_pred_audio, t, audio_latents_step, return_dict=False
1648-
)
1649-
audio_latents_step = audio_latents_step.astype(audio_latents.dtype)
1650-
1651-
if guidance_scale > 1.0:
1652-
latents_next = jnp.concatenate([latents_step] * 2, axis=0)
1653-
audio_latents_next = jnp.concatenate([audio_latents_step] * 2, axis=0)
1654-
else:
1655-
latents_next = latents_step
1656-
audio_latents_next = audio_latents_step
1657-
1658-
new_carry = (latents_next, audio_latents_next, s_state)
1659-
return new_carry, None
1660-
1661-
# Initial carry
1662-
initial_carry = (latents_jax, audio_latents_jax, scheduler_state)
1663-
1664-
# Run scan
1665-
final_carry, _ = nnx.scan(
1666-
scan_body,
1667-
in_axes=(nnx.Carry, 0, None),
1668-
out_axes=(nnx.Carry, 0),
1669-
)(initial_carry, timesteps_jax, transformer)
1670-
1671-
return final_carry[0], final_carry[1]
1580+
transformer = nnx.merge(graphdef, state)
1581+
1582+
def scan_body(carry, t, model):
    """One denoising step of the diffusion loop, shaped for `nnx.scan`.

    Args:
        carry: Tuple ``(latents, audio_latents, s_state)`` — the current video
            latents, audio latents, and scheduler state threaded through the scan.
        t: The current timestep (scalar) scanned over the timestep axis.
        model: The transformer (broadcast into the scan via ``in_axes=None``).

    Returns:
        ``(new_carry, None)`` — the updated ``(latents, audio_latents, s_state)``
        carry and no per-step output.

    NOTE(review): relies on closure variables from the enclosing
    ``run_diffusion_loop`` (``video_embeds_sharded``, ``audio_embeds_sharded``,
    ``new_attention_mask``, ``guidance_scale``, ``batch_size``, ``scan_layers``,
    ``scheduler_step``, etc.) — confirm against the full function signature.
    """
    latents, audio_latents, s_state = carry

    # Resolve logical axis names to mesh axes inside the configured rules.
    with nn_partitioning.axis_rules(logical_axis_rules):
        latents_sharded = latents
        audio_latents_sharded = audio_latents

        # When layers are not scanned, pin the activations' sharding explicitly.
        if not scan_layers:
            activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
            latents_sharded = jax.lax.with_sharding_constraint(latents, activation_axis_names)
            audio_latents_sharded = jax.lax.with_sharding_constraint(audio_latents, activation_axis_names)

        # Expand timestep to batch size
        t_expanded = jnp.expand_dims(t, 0).repeat(latents.shape[0])

        # Joint video + audio denoising prediction.
        noise_pred, noise_pred_audio = model(
            hidden_states=latents_sharded,
            encoder_hidden_states=video_embeds_sharded,
            timestep=t_expanded,
            encoder_attention_mask=new_attention_mask,
            num_frames=latent_num_frames,
            height=latent_height,
            width=latent_width,
            audio_hidden_states=audio_latents_sharded,
            audio_encoder_hidden_states=audio_embeds_sharded,
            audio_encoder_attention_mask=new_attention_mask,
            fps=fps,
            audio_num_frames=audio_num_frames,
            return_dict=False,
        )

        # Classifier-free guidance: the batch is assumed to hold the
        # unconditional half followed by the text-conditioned half
        # (evidenced by the 2-way split on axis 0 and the
        # ``latents[batch_size:]`` slice below) — TODO confirm against the
        # latent-preparation code in the enclosing function.
        if guidance_scale > 1.0:
            noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            # Audio guidance
            (
                noise_pred_audio_uncond,
                noise_pred_audio_text,
            ) = jnp.split(noise_pred_audio, 2, axis=0)
            noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)

            # Step only the conditioned half of the (duplicated) latents.
            latents_step = latents[batch_size:]
            audio_latents_step = audio_latents[batch_size:]
        else:
            latents_step = latents
            audio_latents_step = audio_latents

        # Step scheduler
        latents_step, _ = scheduler_step(s_state, noise_pred, t, latents_step, return_dict=False)
        # Cast back: the scheduler step may change dtype relative to the carry.
        latents_step = latents_step.astype(latents.dtype)

        audio_latents_step, _ = scheduler_step(s_state, noise_pred_audio, t, audio_latents_step, return_dict=False)
        audio_latents_step = audio_latents_step.astype(audio_latents.dtype)

        # Re-duplicate the stepped latents so the next iteration again carries
        # the [uncond; text] doubled batch expected by the CFG split above.
        if guidance_scale > 1.0:
            latents_next = jnp.concatenate([latents_step] * 2, axis=0)
            audio_latents_next = jnp.concatenate([audio_latents_step] * 2, axis=0)
        else:
            latents_next = latents_step
            audio_latents_next = audio_latents_step

    # Scheduler state is threaded through the carry unchanged.
    new_carry = (latents_next, audio_latents_next, s_state)
    return new_carry, None
1645+
1646+
# Initial carry
1647+
initial_carry = (latents_jax, audio_latents_jax, scheduler_state)
1648+
1649+
# Run scan
1650+
final_carry, _ = nnx.scan(
1651+
scan_body,
1652+
in_axes=(nnx.Carry, 0, None),
1653+
out_axes=(nnx.Carry, 0),
1654+
)(initial_carry, timesteps_jax, transformer)
1655+
1656+
return final_carry[0], final_carry[1]

0 commit comments

Comments
 (0)