Skip to content

Commit 76d9ce0

Browse files
committed
markers placed for measuring e2e time for 10 steps
1 parent 123340d commit 76d9ce0

2 files changed

Lines changed: 14 additions & 2 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ sampler: "from_checkpoint"
2424

2525
# Generation parameters
2626
global_batch_size_to_train_on: 1
27-
num_inference_steps: 40
27+
num_inference_steps: 10
2828
guidance_scale: 3.0
2929
fps: 24
3030
pipeline_type: multi-scale
@@ -58,6 +58,12 @@ data_sharding: ['data', 'fsdp', 'context', 'tensor']
5858

5959
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
6060
dcn_fsdp_parallelism: -1
61+
62+
flash_block_sizes: {
63+
block_q: 1024,
64+
block_kv: 1024,
65+
block_kv_compute: 1024
66+
}
6167
dcn_context_parallelism: 1
6268
dcn_tensor_parallelism: 1
6369
ici_data_parallelism: 1

src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -982,9 +982,10 @@ def run_inference(
982982
conditioning_mask,
983983
original_conditioning_mask,
984984
):
985+
total_diffusion_time = 0.0
985986
for i, t in enumerate(scheduler_state.timesteps):
987+
step_start_time = time.perf_counter()
986988
current_timestep = t
987-
988989
latent_model_input = jnp.concatenate([latents] * num_conds) if num_conds > 1 else latents
989990

990991
if not isinstance(current_timestep, (jnp.ndarray, jax.Array)):
@@ -1058,7 +1059,12 @@ def run_inference(
10581059
latents, scheduler_state = denoising_step(
10591060
scheduler, scheduler_state, noise_pred, current_timestep, original_conditioning_mask, t, latents
10601061
)
1062+
latents.block_until_ready()
1063+
step_duration = time.perf_counter() - step_start_time
1064+
total_diffusion_time += step_duration
1065+
max_logging.log(f"[Tuning] Diffusion Step {i} e2e time: {step_duration:.4f} seconds")
10611066

1067+
max_logging.log(f"[Tuning] Total pure diffusion time (all steps): {total_diffusion_time:.4f} seconds")
10621068
return latents, scheduler_state
10631069

10641070

0 commit comments

Comments
 (0)