@@ -85,9 +85,20 @@ def get_git_commit_hash():
8585jax .config .update ("jax_use_shardy_partitioner" , True )
8686
8787
88- def call_pipeline (config , pipeline , prompt , negative_prompt ):
88+ def call_pipeline (config , pipeline , prompt , negative_prompt , num_inference_steps = None ):
89+ """Call the pipeline with optional num_inference_steps override.
90+
91+ Args:
92+ config: The configuration object.
93+ pipeline: The pipeline to call.
94+ prompt: The prompt(s) to use.
95+ negative_prompt: The negative prompt(s) to use.
96+ num_inference_steps: Optional override for number of inference steps.
97+ If None, uses config.num_inference_steps.
98+ """
8999 model_key = config .model_name
90100 model_type = config .model_type
101+ steps = num_inference_steps if num_inference_steps is not None else config .num_inference_steps
91102 if model_type == "I2V" :
92103 image = load_image (config .image_url )
93104 if model_key == WAN2_1 :
@@ -98,7 +109,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
98109 height = config .height ,
99110 width = config .width ,
100111 num_frames = config .num_frames ,
101- num_inference_steps = config . num_inference_steps ,
112+ num_inference_steps = steps ,
102113 guidance_scale = config .guidance_scale ,
103114 )
104115 elif model_key == WAN2_2 :
@@ -109,7 +120,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
109120 height = config .height ,
110121 width = config .width ,
111122 num_frames = config .num_frames ,
112- num_inference_steps = config . num_inference_steps ,
123+ num_inference_steps = steps ,
113124 guidance_scale_low = config .guidance_scale_low ,
114125 guidance_scale_high = config .guidance_scale_high ,
115126 )
@@ -123,7 +134,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
123134 height = config .height ,
124135 width = config .width ,
125136 num_frames = config .num_frames ,
126- num_inference_steps = config . num_inference_steps ,
137+ num_inference_steps = steps ,
127138 guidance_scale = config .guidance_scale ,
128139 use_cfg_cache = config .use_cfg_cache ,
129140 )
@@ -134,7 +145,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
134145 height = config .height ,
135146 width = config .width ,
136147 num_frames = config .num_frames ,
137- num_inference_steps = config . num_inference_steps ,
148+ num_inference_steps = steps ,
138149 guidance_scale_low = config .guidance_scale_low ,
139150 guidance_scale_high = config .guidance_scale_high ,
140151 )
@@ -275,15 +286,37 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
275286 max_logging .log (f"generation time per video: { generation_time_per_video } " )
276287 else :
277288 max_logging .log ("Warning: Number of videos is zero, cannot calculate generation_time_per_video." )
278- s0 = time . perf_counter ()
289+
279290 if config .enable_profiler :
291+ skip_steps = getattr (config , 'skip_first_n_steps_for_profiler' , 0 )
292+ profiler_steps = getattr (config , 'profiler_steps' , config .num_inference_steps )
293+
294+ max_logging .log (f"Profiler: skip_first_n_steps={ skip_steps } , profiler_steps={ profiler_steps } " )
295+
296+ def block_if_jax (x ):
297+ """Block until ready if x is a JAX array, otherwise no-op."""
298+ if hasattr (x , 'block_until_ready' ):
299+ x .block_until_ready ()
300+ return x
301+
302+ for i in range (skip_steps ):
303+ max_logging .log (f"Profiler warmup iteration { i + 1 } /{ skip_steps } " )
304+ warmup_videos = call_pipeline (config , pipeline , prompt , negative_prompt , num_inference_steps = profiler_steps )
305+ # Block until warmup completes
306+ jax .tree_util .tree_map (block_if_jax , warmup_videos )
307+
308+ s0 = time .perf_counter ()
280309 max_utils .activate_profiler (config )
281- videos = call_pipeline (config , pipeline , prompt , negative_prompt )
310+ max_logging .log (f"Profiler: starting profiled run with { profiler_steps } steps" )
311+ profiled_videos = call_pipeline (config , pipeline , prompt , negative_prompt , num_inference_steps = profiler_steps )
312+ # Wait for all computation to finish before stopping profiler
313+ jax .tree_util .tree_map (block_if_jax , profiled_videos )
282314 max_utils .deactivate_profiler (config )
283315 generation_time_with_profiler = time .perf_counter () - s0
284316 max_logging .log (f"generation_time_with_profiler: { generation_time_with_profiler } " )
285317 if writer and jax .process_index () == 0 :
286318 writer .add_scalar ("inference/generation_time_with_profiler" , generation_time_with_profiler , global_step = 0 )
319+ max_logging .log ("Profiler: completed (video not saved)" )
287320
288321 return saved_video_path
289322
0 commit comments