Skip to content

Commit 2d4eae1

Browse files
committed
Changed profiling logic
1 parent 5b91824 commit 2d4eae1

1 file changed

Lines changed: 40 additions & 7 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 40 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -85,9 +85,20 @@ def get_git_commit_hash():
8585
jax.config.update("jax_use_shardy_partitioner", True)
8686

8787

88-
def call_pipeline(config, pipeline, prompt, negative_prompt):
88+
def call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=None):
89+
"""Call the pipeline with optional num_inference_steps override.
90+
91+
Args:
92+
config: The configuration object.
93+
pipeline: The pipeline to call.
94+
prompt: The prompt(s) to use.
95+
negative_prompt: The negative prompt(s) to use.
96+
num_inference_steps: Optional override for number of inference steps.
97+
If None, uses config.num_inference_steps.
98+
"""
8999
model_key = config.model_name
90100
model_type = config.model_type
101+
steps = num_inference_steps if num_inference_steps is not None else config.num_inference_steps
91102
if model_type == "I2V":
92103
image = load_image(config.image_url)
93104
if model_key == WAN2_1:
@@ -98,7 +109,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
98109
height=config.height,
99110
width=config.width,
100111
num_frames=config.num_frames,
101-
num_inference_steps=config.num_inference_steps,
112+
num_inference_steps=steps,
102113
guidance_scale=config.guidance_scale,
103114
)
104115
elif model_key == WAN2_2:
@@ -109,7 +120,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
109120
height=config.height,
110121
width=config.width,
111122
num_frames=config.num_frames,
112-
num_inference_steps=config.num_inference_steps,
123+
num_inference_steps=steps,
113124
guidance_scale_low=config.guidance_scale_low,
114125
guidance_scale_high=config.guidance_scale_high,
115126
)
@@ -123,7 +134,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
123134
height=config.height,
124135
width=config.width,
125136
num_frames=config.num_frames,
126-
num_inference_steps=config.num_inference_steps,
137+
num_inference_steps=steps,
127138
guidance_scale=config.guidance_scale,
128139
use_cfg_cache=config.use_cfg_cache,
129140
)
@@ -134,7 +145,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
134145
height=config.height,
135146
width=config.width,
136147
num_frames=config.num_frames,
137-
num_inference_steps=config.num_inference_steps,
148+
num_inference_steps=steps,
138149
guidance_scale_low=config.guidance_scale_low,
139150
guidance_scale_high=config.guidance_scale_high,
140151
)
@@ -275,15 +286,37 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
275286
max_logging.log(f"generation time per video: {generation_time_per_video}")
276287
else:
277288
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
278-
s0 = time.perf_counter()
289+
279290
if config.enable_profiler:
291+
skip_steps = getattr(config, 'skip_first_n_steps_for_profiler', 0)
292+
profiler_steps = getattr(config, 'profiler_steps', config.num_inference_steps)
293+
294+
max_logging.log(f"Profiler: skip_first_n_steps={skip_steps}, profiler_steps={profiler_steps}")
295+
296+
def block_if_jax(x):
297+
"""Block until ready if x is a JAX array, otherwise no-op."""
298+
if hasattr(x, 'block_until_ready'):
299+
x.block_until_ready()
300+
return x
301+
302+
for i in range(skip_steps):
303+
max_logging.log(f"Profiler warmup iteration {i + 1}/{skip_steps}")
304+
warmup_videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=profiler_steps)
305+
# Block until warmup completes
306+
jax.tree_util.tree_map(block_if_jax, warmup_videos)
307+
308+
s0 = time.perf_counter()
280309
max_utils.activate_profiler(config)
281-
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
310+
max_logging.log(f"Profiler: starting profiled run with {profiler_steps} steps")
311+
profiled_videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=profiler_steps)
312+
# Wait for all computation to finish before stopping profiler
313+
jax.tree_util.tree_map(block_if_jax, profiled_videos)
282314
max_utils.deactivate_profiler(config)
283315
generation_time_with_profiler = time.perf_counter() - s0
284316
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
285317
if writer and jax.process_index() == 0:
286318
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
319+
max_logging.log("Profiler: completed (video not saved)")
287320

288321
return saved_video_path
289322

0 commit comments

Comments (0)