@@ -139,9 +139,7 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
139139
140140
141141def run (config , pipeline = None , filename_prefix = "" ):
142- print ("seed: " , config .seed )
143142 model_key = config .model_name
144- tensorboard_dir = os .path .join (config .output_dir , "tensorboard" )
145143 # Initialize TensorBoard writer
146144 writer = max_utils .initialize_summary_writer (config )
147145 if jax .process_index () == 0 and writer :
@@ -168,17 +166,17 @@ def run(config, pipeline=None, filename_prefix=""):
168166 )
169167
170168 videos = call_pipeline (config , pipeline , prompt , negative_prompt )
171- print ("===================== Model details =======================" )
172- print ("model name: " , config .model_name )
173- print ("model path: " , config .pretrained_model_name_or_path )
174- print ("model type: t2v" )
175- print ("hardware: " , jax .devices ()[0 ].platform )
176- print ("number of devices: " , jax .device_count ())
177- print ("per_device_batch_size: " , config .per_device_batch_size )
178- print ("============================================================" )
169+ max_logging . log ("===================== Model details =======================" )
170+ max_logging . log ("model name: " , config .model_name )
171+ max_logging . log ("model path: " , config .pretrained_model_name_or_path )
172+ max_logging . log ("model type: t2v" )
173+ max_logging . log ("hardware: " , jax .devices ()[0 ].platform )
174+ max_logging . log ("number of devices: " , jax .device_count ())
175+ max_logging . log ("per_device_batch_size: " , config .per_device_batch_size )
176+ max_logging . log ("============================================================" )
179177
180178 compile_time = time .perf_counter () - s0
181- print ("compile_time: " , compile_time )
179+ max_logging . log ("compile_time: " , compile_time )
182180 if writer and jax .process_index () == 0 :
183181 writer .add_scalar ("inference/compile_time" , compile_time , global_step = 0 )
184182 saved_video_path = []
@@ -192,15 +190,15 @@ def run(config, pipeline=None, filename_prefix=""):
192190 s0 = time .perf_counter ()
193191 videos = call_pipeline (config , pipeline , prompt , negative_prompt )
194192 generation_time = time .perf_counter () - s0
195- print ("generation_time: " , generation_time )
193+ max_logging . log ("generation_time: " , generation_time )
196194 if writer and jax .process_index () == 0 :
197195 writer .add_scalar ("inference/generation_time" , generation_time , global_step = 0 )
198196 num_devices = jax .device_count ()
199197 num_videos = num_devices * config .per_device_batch_size
200198 if num_videos > 0 :
201199 generation_time_per_video = generation_time / num_videos
202200 writer .add_scalar ("inference/generation_time_per_video" , generation_time_per_video , global_step = 0 )
203- print (f"generation time per video: { generation_time_per_video } " )
201+ max_logging . log (f"generation time per video: { generation_time_per_video } " )
204202 else :
205203 max_logging .log ("Warning: Number of videos is zero, cannot calculate generation_time_per_video." )
206204
@@ -211,7 +209,7 @@ def run(config, pipeline=None, filename_prefix=""):
211209 videos = call_pipeline (config , pipeline , prompt , negative_prompt )
212210 max_utils .deactivate_profiler (config )
213211 generation_time_with_profiler = time .perf_counter () - s0
214- print ("generation_time_with_profiler: " , generation_time_with_profiler )
212+ max_logging . log ("generation_time_with_profiler: " , generation_time_with_profiler )
215213 if writer and jax .process_index () == 0 :
216214 writer .add_scalar ("inference/generation_time_with_profiler" , generation_time_with_profiler , global_step = 0 )
217215
0 commit comments