Skip to content

Commit 5f458fd

Browse files
committed
Adding Tensorboard logging for inference metrics
1 parent 448ee98 commit 5f458fd

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,16 +167,16 @@ def run(config, pipeline=None, filename_prefix=""):
167167

168168
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
169169
max_logging.log("===================== Model details =======================")
170-
max_logging.log("model name: ", config.model_name)
171-
max_logging.log("model path: ", config.pretrained_model_name_or_path)
170+
max_logging.log(f"model name: {config.model_name}")
171+
max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
172172
max_logging.log("model type: t2v")
173-
max_logging.log("hardware: ", jax.devices()[0].platform)
174-
max_logging.log("number of devices: ", jax.device_count())
175-
max_logging.log("per_device_batch_size: ", config.per_device_batch_size)
173+
max_logging.log(f"hardware: {jax.devices()[0].platform}")
174+
max_logging.log(f"number of devices: {jax.device_count()}")
175+
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")
176176
max_logging.log("============================================================")
177177

178178
compile_time = time.perf_counter() - s0
179-
max_logging.log("compile_time: ", compile_time)
179+
max_logging.log(f"compile_time: {compile_time}")
180180
if writer and jax.process_index() == 0:
181181
writer.add_scalar("inference/compile_time", compile_time, global_step=0)
182182
saved_video_path = []
@@ -190,7 +190,7 @@ def run(config, pipeline=None, filename_prefix=""):
190190
s0 = time.perf_counter()
191191
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
192192
generation_time = time.perf_counter() - s0
193-
max_logging.log("generation_time: ", generation_time)
193+
max_logging.log(f"generation_time: {generation_time}")
194194
if writer and jax.process_index() == 0:
195195
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
196196
num_devices = jax.device_count()
@@ -209,7 +209,7 @@ def run(config, pipeline=None, filename_prefix=""):
209209
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
210210
max_utils.deactivate_profiler(config)
211211
generation_time_with_profiler = time.perf_counter() - s0
212-
max_logging.log("generation_time_with_profiler: ", generation_time_with_profiler)
212+
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
213213
if writer and jax.process_index() == 0:
214214
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
215215

0 commit comments

Comments
 (0)