Skip to content

Commit 50bc6b0

Browse files
committed
markers placed for measuring e2e time for 10 steps
1 parent e1fdded commit 50bc6b0

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,7 +1230,10 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12301230
connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_)
12311231
)
12321232

1233+
import time
1234+
total_diffusion_time = 0.0
12331235
for i, t in enumerate(timesteps):
1236+
step_start_time = time.perf_counter()
12341237
noise_pred, noise_pred_audio = transformer_forward_pass(
12351238
graphdef,
12361239
state,
@@ -1276,6 +1279,12 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12761279
latents_jax = latents_step
12771280
audio_latents_jax = audio_latents_step
12781281

1282+
latents_jax.block_until_ready()
1283+
step_duration = time.perf_counter() - step_start_time
1284+
total_diffusion_time += step_duration
1285+
max_logging.log(f"[Tuning] Diffusion Step {i} e2e time: {step_duration:.4f} seconds")
1286+
max_logging.log(f"[Tuning] Total pure diffusion time (all steps): {total_diffusion_time:.4f} seconds")
1287+
12791288
# 8. Decode Latents
12801289
if guidance_scale > 1.0:
12811290
latents_jax = latents_jax[batch_size:]

0 commit comments

Comments
 (0)