Skip to content

Commit 6f97ff3

Browse files
committed
fix
1 parent dbee23c commit 6f97ff3

1 file changed

Lines changed: 61 additions & 50 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 61 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import flax.linen as nn
2828
import flax.traverse_util
2929
from flax import nnx
30+
from flax.linen import partitioning as nn_partitioning
3031
from transformers import AutoTokenizer, GemmaTokenizer, GemmaTokenizerFast, Gemma3ForConditionalGeneration
3132
from tqdm.auto import tqdm
3233
import qwix
@@ -1222,59 +1223,69 @@ def __call__(
12221223
graphdef, state = nnx.split(self.transformer)
12231224

12241225
# 7. Denoising Loop
1225-
connectors_graphdef, connectors_state = nnx.split(self.connectors)
1226-
1227-
@jax.jit
1228-
def run_connectors(graphdef, state, hidden_states, attention_mask):
1229-
model = nnx.merge(graphdef, state)
1230-
return model(hidden_states, attention_mask)
1231-
1232-
video_embeds, audio_embeds = run_connectors(
1233-
connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_)
1226+
import contextlib
1227+
context_manager = (
1228+
self.mesh if hasattr(self, "mesh") and self.mesh is not None else contextlib.nullcontext()
1229+
)
1230+
axis_rules_context = (
1231+
nn_partitioning.axis_rules(self.config.logical_axis_rules)
1232+
if hasattr(self, "config") and hasattr(self.config, "logical_axis_rules") else contextlib.nullcontext()
12341233
)
12351234

1236-
for i, t in enumerate(timesteps):
1237-
noise_pred, noise_pred_audio = transformer_forward_pass(
1238-
graphdef, state,
1239-
latents_jax,
1240-
audio_latents_jax,
1241-
t,
1242-
video_embeds,
1243-
audio_embeds,
1244-
prompt_attention_mask_jax,
1245-
prompt_attention_mask_jax,
1246-
guidance_scale > 1.0,
1247-
guidance_scale,
1248-
num_frames,
1249-
height,
1250-
width,
1251-
audio_num_frames,
1252-
frame_rate,
1235+
with context_manager, axis_rules_context:
1236+
connectors_graphdef, connectors_state = nnx.split(self.connectors)
1237+
1238+
@jax.jit
1239+
def run_connectors(graphdef, state, hidden_states, attention_mask):
1240+
model = nnx.merge(graphdef, state)
1241+
return model(hidden_states, attention_mask)
1242+
1243+
video_embeds, audio_embeds = run_connectors(
1244+
connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_)
12531245
)
1254-
1255-
if guidance_scale > 1.0:
1256-
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
1257-
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1258-
# Audio guidance
1259-
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
1260-
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1261-
1262-
latents_step = latents_jax[batch_size:]
1263-
audio_latents_step = audio_latents_jax[batch_size:]
1264-
else:
1265-
latents_step = latents_jax
1266-
audio_latents_step = audio_latents_jax
1267-
1268-
# Step
1269-
latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False)
1270-
audio_latents_step, _ = self.scheduler.step(scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False)
1271-
1272-
if guidance_scale > 1.0:
1273-
latents_jax = jnp.concatenate([latents_step] * 2, axis=0)
1274-
audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0)
1275-
else:
1276-
latents_jax = latents_step
1277-
audio_latents_jax = audio_latents_step
1246+
1247+
for i, t in enumerate(timesteps):
1248+
noise_pred, noise_pred_audio = transformer_forward_pass(
1249+
graphdef, state,
1250+
latents_jax,
1251+
audio_latents_jax,
1252+
t,
1253+
video_embeds,
1254+
audio_embeds,
1255+
prompt_attention_mask_jax,
1256+
prompt_attention_mask_jax,
1257+
guidance_scale > 1.0,
1258+
guidance_scale,
1259+
num_frames,
1260+
height,
1261+
width,
1262+
audio_num_frames,
1263+
frame_rate,
1264+
)
1265+
1266+
if guidance_scale > 1.0:
1267+
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
1268+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1269+
# Audio guidance
1270+
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
1271+
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1272+
1273+
latents_step = latents_jax[batch_size:]
1274+
audio_latents_step = audio_latents_jax[batch_size:]
1275+
else:
1276+
latents_step = latents_jax
1277+
audio_latents_step = audio_latents_jax
1278+
1279+
# Step
1280+
latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False)
1281+
audio_latents_step, _ = self.scheduler.step(scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False)
1282+
1283+
if guidance_scale > 1.0:
1284+
latents_jax = jnp.concatenate([latents_step] * 2, axis=0)
1285+
audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0)
1286+
else:
1287+
latents_jax = latents_step
1288+
audio_latents_jax = audio_latents_step
12781289

12791290
# 8. Decode Latents
12801291
if guidance_scale > 1.0:

0 commit comments

Comments (0)