@@ -141,6 +141,11 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
141141def run (config , pipeline = None , filename_prefix = "" ):
142142 print ("seed: " , config .seed )
143143 model_key = config .model_name
144+    config.tensorboard_dir = os.path.join(config.output_dir, "tensorboard")
145+    # Initialize TensorBoard writer
146+    writer = max_utils.initialize_summary_writer(config)
147+    if jax.process_index() == 0 and writer:
148+        max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}")
144149
145150 checkpointer_lib = get_checkpointer (model_key )
146151 WanCheckpointer = checkpointer_lib .WanCheckpointer
@@ -164,7 +169,10 @@ def run(config, pipeline=None, filename_prefix=""):
164169
165170 videos = call_pipeline (config , pipeline , prompt , negative_prompt )
166171
167- print ("compile time: " , (time .perf_counter () - s0 ))
172+    compile_time = time.perf_counter() - s0
173+    print("compile_time: ", compile_time)
174+    if writer and jax.process_index() == 0:
175+        writer.add_scalar("inference/compile_time", compile_time, global_step=0)
168176 saved_video_path = []
169177 for i in range (len (videos )):
170178 video_path = f"{ filename_prefix } wan_output_{ config .seed } _{ i } .mp4"
@@ -175,14 +183,30 @@ def run(config, pipeline=None, filename_prefix=""):
175183
176184 s0 = time .perf_counter ()
177185 videos = call_pipeline (config , pipeline , prompt , negative_prompt )
178- print ("generation time: " , (time .perf_counter () - s0 ))
186+    generation_time = time.perf_counter() - s0
187+    print("generation_time: ", generation_time)
188+    num_devices = jax.device_count()
189+    num_videos = num_devices * config.per_device_batch_size
190+    if num_videos > 0:
191+        generation_time_per_video = generation_time / num_videos
192+        print(f"generation time per video: {generation_time_per_video}")
193+    else:
194+        max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
195+    if writer and jax.process_index() == 0:
196+        writer.add_scalar("inference/generation_time", generation_time, global_step=0)
197+        if num_videos > 0:
198+            writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
198+
179199
180200 s0 = time .perf_counter ()
181201 if config .enable_profiler :
182202 max_utils .activate_profiler (config )
183203 videos = call_pipeline (config , pipeline , prompt , negative_prompt )
184204 max_utils .deactivate_profiler (config )
185- print ("generation time: " , (time .perf_counter () - s0 ))
205+    generation_time_with_profiler = time.perf_counter() - s0
206+    print("generation_time_with_profiler: ", generation_time_with_profiler)
207+    if writer and jax.process_index() == 0:
208+        writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
209+        writer.flush()  # buffered scalars are otherwise never guaranteed to be persisted
210+
186210 return saved_video_path
187211
188212
0 commit comments