Skip to content

Commit fd36989

Browse files
committed
Added tensorboard logging for inference metrics
1 parent 4896870 commit fd36989

1 file changed

Lines changed: 27 additions & 3 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
141141
def run(config, pipeline=None, filename_prefix=""):
142142
print("seed: ", config.seed)
143143
model_key = config.model_name
144+
config.tensorboard_dir = os.path.join(config.output_dir, "tensorboard")
145+
# Initialize TensorBoard writer
146+
writer = max_utils.initialize_summary_writer(config)
147+
if jax.process_index() == 0 and writer:
148+
max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}")
144149

145150
checkpointer_lib = get_checkpointer(model_key)
146151
WanCheckpointer = checkpointer_lib.WanCheckpointer
@@ -164,7 +169,10 @@ def run(config, pipeline=None, filename_prefix=""):
164169

165170
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
166171

167-
print("compile time: ", (time.perf_counter() - s0))
172+
compile_time = time.perf_counter() - s0
173+
print("compile_time: ", compile_time)
174+
if writer and jax.process_index() == 0:
175+
writer.add_scalar("inference/compile_time", compile_time, global_step=0)
168176
saved_video_path = []
169177
for i in range(len(videos)):
170178
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
@@ -175,14 +183,30 @@ def run(config, pipeline=None, filename_prefix=""):
175183

176184
s0 = time.perf_counter()
177185
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
178-
print("generation time: ", (time.perf_counter() - s0))
186+
generation_time = time.perf_counter() - s0
187+
print("generation_time: ", generation_time)
188+
if writer and jax.process_index() == 0:
189+
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
190+
num_devices = jax.device_count()
191+
num_videos = num_devices * config.per_device_batch_size
192+
if num_videos > 0:
193+
generation_time_per_video = generation_time / num_videos
194+
writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
195+
print(f"generation time per video: {generation_time_per_video}")
196+
else:
197+
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
198+
179199

180200
s0 = time.perf_counter()
181201
if config.enable_profiler:
182202
max_utils.activate_profiler(config)
183203
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
184204
max_utils.deactivate_profiler(config)
185-
print("generation time: ", (time.perf_counter() - s0))
205+
generation_time_with_profiler = time.perf_counter() - s0
206+
print("generation_time_with_profiler: ", generation_time_with_profiler)
207+
if writer and jax.process_index() == 0:
208+
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
209+
186210
return saved_video_path
187211

188212

0 commit comments

Comments (0)