Skip to content

Commit e5a9df2

Browse files
committed
Start profiling after warmup step
1 parent 8758e29 commit e5a9df2

2 files changed

Lines changed: 37 additions & 14 deletions

File tree

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,5 @@ compile_topology_num_slices: -1
103103
quantization_local_shard_count: -1
104104
use_qwix_quantization: False
105105
jit_initializers: True
106-
enable_single_replica_ckpt_restoring: False
106+
enable_single_replica_ckpt_restoring: False
107+
enable_profiler: True

src/maxdiffusion/generate_ltx_video.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
2121
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem
2222
import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor
23-
from maxdiffusion import pyconfig, max_logging
23+
from maxdiffusion import pyconfig, max_logging, max_utils
2424
import torchvision.transforms.functional as TVF
2525
import imageio
2626
from datetime import datetime
@@ -29,6 +29,7 @@
2929
from pathlib import Path
3030
from PIL import Image
3131
import torch
32+
import jax
3233

3334

3435
def calculate_padding(
@@ -206,19 +207,40 @@ def run(config):
206207
else None
207208
)
208209

210+
pipeline_args = {
211+
"height": height_padded,
212+
"width": width_padded,
213+
"num_frames": num_frames_padded,
214+
"is_video": True,
215+
"output_type": "pt",
216+
"config": config,
217+
"enhance_prompt": enhance_prompt,
218+
"conditioning_items": conditioning_items,
219+
"seed": config.seed,
220+
}
221+
222+
223+
# Warm-up call
209224
s0 = time.perf_counter()
210-
images = pipeline(
211-
height=height_padded,
212-
width=width_padded,
213-
num_frames=num_frames_padded,
214-
is_video=True,
215-
output_type="pt",
216-
config=config,
217-
enhance_prompt=enhance_prompt,
218-
conditioning_items=conditioning_items,
219-
seed=config.seed,
220-
)
221-
max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.")
225+
images = pipeline(**pipeline_args)
226+
max_logging.log(f"Warmup time: {time.perf_counter() - s0:.1f}s.")
227+
228+
# Normal call
229+
s0 = time.perf_counter()
230+
images = pipeline(**pipeline_args)
231+
max_logging.log(f"Generation time: {time.perf_counter() - s0:.1f}s.")
232+
233+
# Profiled call
234+
if config.enable_profiler:
235+
profile_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
236+
profiler_output_path = f"gs://hjajoo-ai-ninja-bucket/ltx-video/profiler_traces/{profile_timestamp}"
237+
jax.profiler.start_trace(profiler_output_path)
238+
max_logging.log(f"JAX profiler started. Traces will be saved to: {profiler_output_path}")
239+
s0 = time.perf_counter()
240+
images = pipeline(**pipeline_args)
241+
jax.profiler.stop_trace()
242+
max_logging.log(f"JAX profiler stopped.")
243+
max_logging.log(f"Generation time with profiler: {time.perf_counter() - s0:.1f}s.")
222244

223245
(pad_left, pad_right, pad_top, pad_bottom) = padding
224246
pad_bottom = -pad_bottom

0 commit comments

Comments
 (0)