|
20 | 20 | from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline |
21 | 21 | from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem |
22 | 22 | import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor |
23 | | -from maxdiffusion import pyconfig, max_logging |
| 23 | +from maxdiffusion import pyconfig, max_logging, max_utils |
24 | 24 | import torchvision.transforms.functional as TVF |
25 | 25 | import imageio |
26 | 26 | from datetime import datetime |
|
29 | 29 | from pathlib import Path |
30 | 30 | from PIL import Image |
31 | 31 | import torch |
| 32 | +import jax |
32 | 33 |
|
33 | 34 |
|
34 | 35 | def calculate_padding( |
@@ -206,19 +207,40 @@ def run(config): |
206 | 207 | else None |
207 | 208 | ) |
208 | 209 |
|
| 210 | + pipeline_args = { |
| 211 | + "height": height_padded, |
| 212 | + "width": width_padded, |
| 213 | + "num_frames": num_frames_padded, |
| 214 | + "is_video": True, |
| 215 | + "output_type": "pt", |
| 216 | + "config": config, |
| 217 | + "enhance_prompt": enhance_prompt, |
| 218 | + "conditioning_items": conditioning_items, |
| 219 | + "seed": config.seed, |
| 220 | + } |
| 221 | + |
| 222 | + |
| 223 | + # Warm-up call |
209 | 224 | s0 = time.perf_counter() |
210 | | - images = pipeline( |
211 | | - height=height_padded, |
212 | | - width=width_padded, |
213 | | - num_frames=num_frames_padded, |
214 | | - is_video=True, |
215 | | - output_type="pt", |
216 | | - config=config, |
217 | | - enhance_prompt=enhance_prompt, |
218 | | - conditioning_items=conditioning_items, |
219 | | - seed=config.seed, |
220 | | - ) |
221 | | - max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.") |
| 225 | + images = pipeline(**pipeline_args) |
| 226 | + max_logging.log(f"Warmup time: {time.perf_counter() - s0:.1f}s.") |
| 227 | + |
| 228 | + # Normal call |
| 229 | + s0 = time.perf_counter() |
| 230 | + images = pipeline(**pipeline_args) |
| 231 | + max_logging.log(f"Generation time: {time.perf_counter() - s0:.1f}s.") |
| 232 | + |
| 233 | + # Profiled call |
| 234 | + if config.enable_profiler: |
| 235 | + profile_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") |
| 236 | + profiler_output_path = f"gs://hjajoo-ai-ninja-bucket/ltx-video/profiler_traces/{profile_timestamp}" |
| 237 | + jax.profiler.start_trace(profiler_output_path) |
| 238 | + max_logging.log(f"JAX profiler started. Traces will be saved to: {profiler_output_path}") |
| 239 | + s0 = time.perf_counter() |
| 240 | + images = pipeline(**pipeline_args) |
| 241 | + jax.profiler.stop_trace() |
| 242 | + max_logging.log(f"JAX profiler stopped.") |
| 243 | + max_logging.log(f"Generation time with profiler: {time.perf_counter() - s0:.1f}s.") |
222 | 244 |
|
223 | 245 | (pad_left, pad_right, pad_top, pad_bottom) = padding |
224 | 246 | pad_bottom = -pad_bottom |
|
0 commit comments