Skip to content

Commit 4e4d799

Browse files
committed
Add TensorBoard logging for inference metrics
1 parent dfcd0c0 commit 4e4d799

1 file changed

Lines changed: 12 additions & 14 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -139,9 +139,7 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
139139

140140

141141
def run(config, pipeline=None, filename_prefix=""):
142-
print("seed: ", config.seed)
143142
model_key = config.model_name
144-
tensorboard_dir = os.path.join(config.output_dir, "tensorboard")
145143
# Initialize TensorBoard writer
146144
writer = max_utils.initialize_summary_writer(config)
147145
if jax.process_index() == 0 and writer:
@@ -168,17 +166,17 @@ def run(config, pipeline=None, filename_prefix=""):
168166
)
169167

170168
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
171-
print("===================== Model details =======================")
172-
print("model name: ", config.model_name)
173-
print("model path: ", config.pretrained_model_name_or_path)
174-
print("model type: t2v")
175-
print("hardware: ", jax.devices()[0].platform)
176-
print("number of devices: ", jax.device_count())
177-
print("per_device_batch_size: ", config.per_device_batch_size)
178-
print("============================================================")
169+
max_logging.log("===================== Model details =======================")
170+
max_logging.log("model name: ", config.model_name)
171+
max_logging.log("model path: ", config.pretrained_model_name_or_path)
172+
max_logging.log("model type: t2v")
173+
max_logging.log("hardware: ", jax.devices()[0].platform)
174+
max_logging.log("number of devices: ", jax.device_count())
175+
max_logging.log("per_device_batch_size: ", config.per_device_batch_size)
176+
max_logging.log("============================================================")
179177

180178
compile_time = time.perf_counter() - s0
181-
print("compile_time: ", compile_time)
179+
max_logging.log("compile_time: ", compile_time)
182180
if writer and jax.process_index() == 0:
183181
writer.add_scalar("inference/compile_time", compile_time, global_step=0)
184182
saved_video_path = []
@@ -192,15 +190,15 @@ def run(config, pipeline=None, filename_prefix=""):
192190
s0 = time.perf_counter()
193191
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
194192
generation_time = time.perf_counter() - s0
195-
print("generation_time: ", generation_time)
193+
max_logging.log("generation_time: ", generation_time)
196194
if writer and jax.process_index() == 0:
197195
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
198196
num_devices = jax.device_count()
199197
num_videos = num_devices * config.per_device_batch_size
200198
if num_videos > 0:
201199
generation_time_per_video = generation_time / num_videos
202200
writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
203-
print(f"generation time per video: {generation_time_per_video}")
201+
max_logging.log(f"generation time per video: {generation_time_per_video}")
204202
else:
205203
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
206204

@@ -211,7 +209,7 @@ def run(config, pipeline=None, filename_prefix=""):
211209
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
212210
max_utils.deactivate_profiler(config)
213211
generation_time_with_profiler = time.perf_counter() - s0
214-
print("generation_time_with_profiler: ", generation_time_with_profiler)
212+
max_logging.log("generation_time_with_profiler: ", generation_time_with_profiler)
215213
if writer and jax.process_index() == 0:
216214
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
217215

0 commit comments

Comments (0)