Skip to content

Commit f14d691

Browse files
committed
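Add more jax.named_scope annotations (VAE _encode/_decode, text encoder forward) and remove the jax.profiler.TraceAnnotation wrappers from the LTX2 pipeline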
1 parent 25123c6 commit f14d691

3 files changed

Lines changed: 125 additions & 127 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 21 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1513,15 +1513,16 @@ def _temporal_tiled_decode(
15131513
return FlaxDecoderOutput(sample=dec)
15141514

15151515
def _encode(self, x: jax.Array, key: Optional[jax.Array] = None, causal: Optional[bool] = None) -> jax.Array:
1516-
B, T, H, W, C = x.shape
1517-
if self.use_framewise_decoding and T > self.tile_sample_min_num_frames:
1518-
return self._temporal_tiled_encode(x, key=key, causal=causal)
1516+
with jax.named_scope("VAE _encode"):
1517+
B, T, H, W, C = x.shape
1518+
if self.use_framewise_decoding and T > self.tile_sample_min_num_frames:
1519+
return self._temporal_tiled_encode(x, key=key, causal=causal)
15191520

1520-
if self.use_tiling and (W > self.tile_sample_min_width or H > self.tile_sample_min_height):
1521-
return self.tiled_encode(x, key=key, causal=causal)
1521+
if self.use_tiling and (W > self.tile_sample_min_width or H > self.tile_sample_min_height):
1522+
return self.tiled_encode(x, key=key, causal=causal)
15221523

1523-
enc = self.encoder(x, key=key, causal=causal)
1524-
return enc
1524+
enc = self.encoder(x, key=key, causal=causal)
1525+
return enc
15251526

15261527
def _decode(
15271528
self,
@@ -1531,22 +1532,23 @@ def _decode(
15311532
causal: Optional[bool] = None,
15321533
return_dict: bool = True,
15331534
) -> Union[FlaxDecoderOutput, jax.Array]:
1534-
B, T, H, W, C = z.shape
1535-
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1536-
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1537-
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
1535+
with jax.named_scope("VAE _decode"):
1536+
B, T, H, W, C = z.shape
1537+
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1538+
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1539+
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
15381540

1539-
if self.use_framewise_decoding and T > tile_latent_min_num_frames:
1540-
return self._temporal_tiled_decode(z, temb, key=key, causal=causal, return_dict=return_dict)
1541+
if self.use_framewise_decoding and T > tile_latent_min_num_frames:
1542+
return self._temporal_tiled_decode(z, temb, key=key, causal=causal, return_dict=return_dict)
15411543

1542-
if self.use_tiling and (W > tile_latent_min_width or H > tile_latent_min_height):
1543-
return self.tiled_decode(z, temb, key=key, causal=causal, return_dict=return_dict)
1544+
if self.use_tiling and (W > tile_latent_min_width or H > tile_latent_min_height):
1545+
return self.tiled_decode(z, temb, key=key, causal=causal, return_dict=return_dict)
15441546

1545-
dec = self.decoder(z, temb, key=key, causal=causal)
1547+
dec = self.decoder(z, temb, key=key, causal=causal)
15461548

1547-
if not return_dict:
1548-
return (dec,)
1549-
return FlaxDecoderOutput(sample=dec)
1549+
if not return_dict:
1550+
return (dec,)
1551+
return FlaxDecoderOutput(sample=dec)
15501552

15511553
def encode(
15521554
self,

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,12 @@ def __call__(
108108
Returns:
109109
(video_embeds, audio_embeds, new_attention_mask)
110110
"""
111-
# 1. Shared Feature Extraction
112-
features = self.feature_extractor(hidden_states, attention_mask)
111+
with jax.named_scope("Text Encoder Forward"):
112+
# 1. Shared Feature Extraction
113+
features = self.feature_extractor(hidden_states, attention_mask)
113114

114-
# 2. Parallel Connection
115-
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
116-
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
115+
# 2. Parallel Connection
116+
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
117+
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
117118

118-
return video_embeds, audio_embeds, new_attention_mask
119+
return video_embeds, audio_embeds, new_attention_mask

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 97 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,19 +1096,18 @@ def __call__(
10961096
# 2. Encode inputs (Text)
10971097
import time
10981098
s_text = time.perf_counter()
1099-
with jax.profiler.TraceAnnotation("Encode Inputs"):
1100-
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
1101-
prompt,
1102-
negative_prompt,
1103-
do_classifier_free_guidance=guidance_scale > 1.0,
1104-
num_videos_per_prompt=num_videos_per_prompt,
1105-
prompt_embeds=prompt_embeds,
1106-
negative_prompt_embeds=negative_prompt_embeds,
1107-
prompt_attention_mask=prompt_attention_mask,
1108-
negative_prompt_attention_mask=negative_prompt_attention_mask,
1109-
max_sequence_length=max_sequence_length,
1110-
dtype=dtype,
1111-
)
1099+
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
1100+
prompt,
1101+
negative_prompt,
1102+
do_classifier_free_guidance=guidance_scale > 1.0,
1103+
num_videos_per_prompt=num_videos_per_prompt,
1104+
prompt_embeds=prompt_embeds,
1105+
negative_prompt_embeds=negative_prompt_embeds,
1106+
prompt_attention_mask=prompt_attention_mask,
1107+
negative_prompt_attention_mask=negative_prompt_attention_mask,
1108+
max_sequence_length=max_sequence_length,
1109+
dtype=dtype,
1110+
)
11121111
t_text = time.perf_counter() - s_text
11131112
max_logging.log(f"[Tuning] Prompt encoding took: {t_text:.4f} seconds")
11141113

@@ -1121,17 +1120,16 @@ def __call__(
11211120

11221121
key_latents, key_audio = jax.random.split(generator)
11231122

1124-
with jax.profiler.TraceAnnotation("Prepare Video Latents"):
1125-
latents = self.prepare_latents(
1126-
batch_size=batch_size,
1127-
height=height,
1128-
width=width,
1129-
num_frames=num_frames,
1130-
noise_scale=noise_scale,
1131-
dtype=dtype,
1132-
generator=key_latents,
1133-
latents=latents,
1134-
)
1123+
latents = self.prepare_latents(
1124+
batch_size=batch_size,
1125+
height=height,
1126+
width=width,
1127+
num_frames=num_frames,
1128+
noise_scale=noise_scale,
1129+
dtype=dtype,
1130+
generator=key_latents,
1131+
latents=latents,
1132+
)
11351133

11361134
latent_height = height // self.vae_spatial_compression_ratio
11371135
latent_width = width // self.vae_spatial_compression_ratio
@@ -1150,16 +1148,15 @@ def __call__(
11501148
)
11511149
audio_num_frames = round(duration_s * audio_latents_per_second)
11521150

1153-
with jax.profiler.TraceAnnotation("Prepare Audio Latents"):
1154-
audio_latents = self.prepare_audio_latents(
1155-
batch_size=batch_size,
1156-
num_channels_latents=audio_channels,
1157-
audio_latent_length=audio_num_frames,
1158-
noise_scale=noise_scale,
1159-
dtype=dtype,
1160-
generator=key_audio,
1161-
latents=audio_latents,
1162-
)
1151+
audio_latents = self.prepare_audio_latents(
1152+
batch_size=batch_size,
1153+
num_channels_latents=audio_channels,
1154+
audio_latent_length=audio_num_frames,
1155+
noise_scale=noise_scale,
1156+
dtype=dtype,
1157+
generator=key_audio,
1158+
latents=audio_latents,
1159+
)
11631160

11641161
# 5. Prepare Timesteps
11651162
sigmas = jnp.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
@@ -1242,57 +1239,56 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12421239
import time
12431240
total_diffusion_time = 0.0
12441241
for i, t in enumerate(timesteps):
1245-
with jax.profiler.TraceAnnotation(f"Diffusion Step {i}"):
1246-
step_start_time = time.perf_counter()
1247-
noise_pred, noise_pred_audio = transformer_forward_pass(
1248-
graphdef,
1249-
state,
1250-
latents_jax,
1251-
audio_latents_jax,
1252-
t,
1253-
video_embeds,
1254-
audio_embeds,
1255-
new_attention_mask,
1256-
new_attention_mask,
1257-
guidance_scale > 1.0,
1258-
guidance_scale,
1259-
latent_num_frames,
1260-
latent_height,
1261-
latent_width,
1262-
audio_num_frames,
1263-
frame_rate,
1264-
)
1265-
1266-
if guidance_scale > 1.0:
1267-
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
1268-
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1269-
# Audio guidance
1270-
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
1271-
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1272-
1273-
latents_step = latents_jax[batch_size:]
1274-
audio_latents_step = audio_latents_jax[batch_size:]
1275-
else:
1276-
latents_step = latents_jax
1277-
audio_latents_step = audio_latents_jax
1278-
1279-
# Step
1280-
latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False)
1281-
audio_latents_step, _ = self.scheduler.step(
1282-
scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False
1283-
)
1284-
1285-
if guidance_scale > 1.0:
1286-
latents_jax = jnp.concatenate([latents_step] * 2, axis=0)
1287-
audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0)
1288-
else:
1289-
latents_jax = latents_step
1290-
audio_latents_jax = audio_latents_step
1291-
1292-
latents_jax.block_until_ready()
1293-
step_duration = time.perf_counter() - step_start_time
1294-
total_diffusion_time += step_duration
1295-
max_logging.log(f"[Tuning] Diffusion Step {i} e2e time: {step_duration:.4f} seconds")
1242+
step_start_time = time.perf_counter()
1243+
noise_pred, noise_pred_audio = transformer_forward_pass(
1244+
graphdef,
1245+
state,
1246+
latents_jax,
1247+
audio_latents_jax,
1248+
t,
1249+
video_embeds,
1250+
audio_embeds,
1251+
new_attention_mask,
1252+
new_attention_mask,
1253+
guidance_scale > 1.0,
1254+
guidance_scale,
1255+
latent_num_frames,
1256+
latent_height,
1257+
latent_width,
1258+
audio_num_frames,
1259+
frame_rate,
1260+
)
1261+
1262+
if guidance_scale > 1.0:
1263+
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
1264+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1265+
# Audio guidance
1266+
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
1267+
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1268+
1269+
latents_step = latents_jax[batch_size:]
1270+
audio_latents_step = audio_latents_jax[batch_size:]
1271+
else:
1272+
latents_step = latents_jax
1273+
audio_latents_step = audio_latents_jax
1274+
1275+
# Step
1276+
latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False)
1277+
audio_latents_step, _ = self.scheduler.step(
1278+
scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False
1279+
)
1280+
1281+
if guidance_scale > 1.0:
1282+
latents_jax = jnp.concatenate([latents_step] * 2, axis=0)
1283+
audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0)
1284+
else:
1285+
latents_jax = latents_step
1286+
audio_latents_jax = audio_latents_step
1287+
1288+
latents_jax.block_until_ready()
1289+
step_duration = time.perf_counter() - step_start_time
1290+
total_diffusion_time += step_duration
1291+
max_logging.log(f"[Tuning] Diffusion Step {i} e2e time: {step_duration:.4f} seconds")
12961292
max_logging.log(f"[Tuning] Total pure diffusion time (all steps): {total_diffusion_time:.4f} seconds")
12971293

12981294
# 8. Decode Latents
@@ -1354,27 +1350,26 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13541350
max_logging.log(f"[Tuning] Failed to apply sharding constraint: {e}")
13551351

13561352
s_vae = time.perf_counter()
1357-
with jax.profiler.TraceAnnotation("VAE Decode Video"):
1358-
if getattr(self.vae.config, "timestep_conditioning", False):
1359-
noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype)
1353+
if getattr(self.vae.config, "timestep_conditioning", False):
1354+
noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype)
13601355

1361-
if not isinstance(decode_timestep, list):
1362-
decode_timestep = [decode_timestep] * batch_size
1363-
if decode_noise_scale is None:
1364-
decode_noise_scale = decode_timestep
1365-
elif not isinstance(decode_noise_scale, list):
1366-
decode_noise_scale = [decode_noise_scale] * batch_size
1356+
if not isinstance(decode_timestep, list):
1357+
decode_timestep = [decode_timestep] * batch_size
1358+
if decode_noise_scale is None:
1359+
decode_noise_scale = decode_timestep
1360+
elif not isinstance(decode_noise_scale, list):
1361+
decode_noise_scale = [decode_noise_scale] * batch_size
13671362

1368-
timestep = jnp.array(decode_timestep, dtype=latents.dtype)
1369-
decode_noise_scale = jnp.array(decode_noise_scale, dtype=latents.dtype)[:, None, None, None, None]
1363+
timestep = jnp.array(decode_timestep, dtype=latents.dtype)
1364+
decode_noise_scale = jnp.array(decode_noise_scale, dtype=latents.dtype)[:, None, None, None, None]
13701365

1371-
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
1366+
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
13721367

1373-
latents = latents.astype(self.vae.dtype)
1374-
video = self.vae.decode(latents, temb=timestep, return_dict=False)[0]
1375-
else:
1376-
latents = latents.astype(self.vae.dtype)
1377-
video = self.vae.decode(latents, return_dict=False)[0]
1368+
latents = latents.astype(self.vae.dtype)
1369+
video = self.vae.decode(latents, temb=timestep, return_dict=False)[0]
1370+
else:
1371+
latents = latents.astype(self.vae.dtype)
1372+
video = self.vae.decode(latents, return_dict=False)[0]
13781373
t_vae = time.perf_counter() - s_vae
13791374
max_logging.log(f"[Tuning] VAE decoding took: {t_vae:.4f} seconds")
13801375
# Post-process video (converts to numpy/PIL)

0 commit comments

Comments
 (0)