Skip to content

Commit 438fefd

Browse files
committed
Format fix
1 parent 2d4eae1 commit 438fefd

2 files changed

Lines changed: 17 additions & 11 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def get_git_commit_hash():
8787

8888
def call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=None):
8989
"""Call the pipeline with optional num_inference_steps override.
90-
90+
9191
Args:
9292
config: The configuration object.
9393
pipeline: The pipeline to call.
@@ -290,25 +290,31 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
290290
if config.enable_profiler:
291291
skip_steps = getattr(config, 'skip_first_n_steps_for_profiler', 0)
292292
profiler_steps = getattr(config, 'profiler_steps', config.num_inference_steps)
293-
294-
max_logging.log(f"Profiler: skip_first_n_steps={skip_steps}, profiler_steps={profiler_steps}")
295-
293+
profile_all = profiler_steps == -1
294+
steps_for_profile = config.num_inference_steps if profile_all else profiler_steps
295+
296+
if profile_all:
297+
max_logging.log(f"Profiler: profiling all {steps_for_profile} inference steps (profiler_steps=-1)")
298+
else:
299+
max_logging.log(f"Profiler: profiling {steps_for_profile} steps out of {config.num_inference_steps} total")
300+
max_logging.log(f"Profiler: skip_first_n_steps={skip_steps}")
301+
296302
def block_if_jax(x):
297303
"""Block until ready if x is a JAX array, otherwise no-op."""
298304
if hasattr(x, 'block_until_ready'):
299305
x.block_until_ready()
300306
return x
301-
307+
302308
for i in range(skip_steps):
303309
max_logging.log(f"Profiler warmup iteration {i + 1}/{skip_steps}")
304-
warmup_videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=profiler_steps)
310+
warmup_videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=steps_for_profile)
305311
# Block until warmup completes
306312
jax.tree_util.tree_map(block_if_jax, warmup_videos)
307-
313+
308314
s0 = time.perf_counter()
309315
max_utils.activate_profiler(config)
310-
max_logging.log(f"Profiler: starting profiled run with {profiler_steps} steps")
311-
profiled_videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=profiler_steps)
316+
max_logging.log(f"Profiler: starting profiled run with {steps_for_profile} steps")
317+
profiled_videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=steps_for_profile)
312318
# Wait for all computation to finish before stopping profiler
313319
jax.tree_util.tree_map(block_if_jax, profiled_videos)
314320
max_utils.deactivate_profiler(config)

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,13 +1088,13 @@ def __call__(
10881088
query_proj = _unflatten_heads(query_proj, self.heads)
10891089
key_proj = _unflatten_heads(key_proj, self.heads)
10901090
value_proj = _unflatten_heads(value_proj, self.heads)
1091-
1091+
10921092
# Enforce sequence parallelism on the new axis 2 (LENGTH) before doing the ROPE math
10931093
axis_names_qkv = nn.logical_to_mesh_axes((BATCH, HEAD, LENGTH, D_KV))
10941094
query_proj = jax.lax.with_sharding_constraint(query_proj, axis_names_qkv)
10951095
key_proj = jax.lax.with_sharding_constraint(key_proj, axis_names_qkv)
10961096
value_proj = jax.lax.with_sharding_constraint(value_proj, axis_names_qkv)
1097-
1097+
10981098
# output of _unflatten_heads Batch, heads, seq_len, head_dim
10991099
query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
11001100

0 commit comments

Comments
 (0)