@@ -87,7 +87,7 @@ def get_git_commit_hash():
8787
8888def call_pipeline (config , pipeline , prompt , negative_prompt , num_inference_steps = None ):
8989 """Call the pipeline with optional num_inference_steps override.
90-
90+
9191 Args:
9292 config: The configuration object.
9393 pipeline: The pipeline to call.
@@ -290,25 +290,31 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
290290 if config .enable_profiler :
291291 skip_steps = getattr (config , 'skip_first_n_steps_for_profiler' , 0 )
292292 profiler_steps = getattr (config , 'profiler_steps' , config .num_inference_steps )
293-
294- max_logging .log (f"Profiler: skip_first_n_steps={ skip_steps } , profiler_steps={ profiler_steps } " )
295-
293+ profile_all = profiler_steps == - 1
294+ steps_for_profile = config .num_inference_steps if profile_all else profiler_steps
295+
296+ if profile_all :
297+ max_logging .log (f"Profiler: profiling all { steps_for_profile } inference steps (profiler_steps=-1)" )
298+ else :
299+ max_logging .log (f"Profiler: profiling { steps_for_profile } steps out of { config .num_inference_steps } total" )
300+ max_logging .log (f"Profiler: skip_first_n_steps={ skip_steps } " )
301+
296302 def block_if_jax (x ):
297303 """Block until ready if x is a JAX array, otherwise no-op."""
298304 if hasattr (x , 'block_until_ready' ):
299305 x .block_until_ready ()
300306 return x
301-
307+
302308 for i in range (skip_steps ):
303309 max_logging .log (f"Profiler warmup iteration { i + 1 } /{ skip_steps } " )
304- warmup_videos = call_pipeline (config , pipeline , prompt , negative_prompt , num_inference_steps = profiler_steps )
310+ warmup_videos = call_pipeline (config , pipeline , prompt , negative_prompt , num_inference_steps = steps_for_profile )
305311 # Block until warmup completes
306312 jax .tree_util .tree_map (block_if_jax , warmup_videos )
307-
313+
308314 s0 = time .perf_counter ()
309315 max_utils .activate_profiler (config )
310- max_logging .log (f"Profiler: starting profiled run with { profiler_steps } steps" )
311- profiled_videos = call_pipeline (config , pipeline , prompt , negative_prompt , num_inference_steps = profiler_steps )
316+ max_logging .log (f"Profiler: starting profiled run with { steps_for_profile } steps" )
317+ profiled_videos = call_pipeline (config , pipeline , prompt , negative_prompt , num_inference_steps = steps_for_profile )
312318 # Wait for all computation to finish before stopping profiler
313319 jax .tree_util .tree_map (block_if_jax , profiled_videos )
314320 max_utils .deactivate_profiler (config )
0 commit comments