@@ -211,8 +211,9 @@ def prepare_sample_eval(features):
211211 def start_training (self ):
212212
213213 pipeline = self .load_checkpoint ()
214- # Generate a sample before training to compare against generated sample after training.
215- pretrained_video_path = generate_sample (self .config , pipeline , filename_prefix = "pre-training-" )
214+ if self .config .enable_ssim :
215+ # Generate a sample before training to compare against generated sample after training.
216+ pretrained_video_path = generate_sample (self .config , pipeline , filename_prefix = "pre-training-" )
216217
217218 if self .config .eval_every == - 1 or (not self .config .enable_generate_video_for_eval ):
218219 # save some memory.
@@ -230,8 +231,57 @@ def start_training(self):
230231 # Returns pipeline with trained transformer state
231232 pipeline = self .training_loop (pipeline , optimizer , learning_rate_scheduler , train_data_iterator )
232233
233- posttrained_video_path = generate_sample (self .config , pipeline , filename_prefix = "post-training-" )
234- print_ssim (pretrained_video_path , posttrained_video_path )
234+ if self .config .enable_ssim :
235+ posttrained_video_path = generate_sample (self .config , pipeline , filename_prefix = "post-training-" )
236+ print_ssim (pretrained_video_path , posttrained_video_path )
237+
def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer):
    """Run one pass over the eval dataset and report the averaged eval loss.

    A fresh eval iterator is built with ``self.load_dataset(mesh, is_training=False)``
    and consumed until ``load_next_batch`` raises ``StopIteration``. For every
    batch the eval-step loss and the batch's ``"timesteps"`` entry are gathered
    across processes; on process 0 the flattened values are paired element-wise
    and bucketed by integer timestep. After the pass, each bucket is truncated
    to ``self.config.eval_max_number_of_samples_in_bucket`` entries and averaged,
    and the mean of those per-timestep means is logged and, if ``writer`` is
    truthy, written under the ``"learning/eval_loss"`` tag.

    Args:
        mesh: Context manager entered (together with the logical axis rules)
            around each ``p_eval_step`` call; also forwarded to ``load_dataset``.
        eval_rng_key: Initial RNG value handed to the first ``p_eval_step`` call;
            subsequent calls use the RNG returned by the previous call.
        step: Training step used in log messages and as the writer's step index.
        p_eval_step: Callable invoked as
            ``p_eval_step(state, eval_batch, eval_rng, scheduler_state)`` and
            returning ``(metrics, new_rng)``.
        state: Passed through unchanged to ``p_eval_step``.
        scheduler_state: Passed through unchanged to ``p_eval_step``.
        writer: Optional object exposing ``add_scalar(tag, value, step)``;
            skipped when falsy.

    Returns:
        None. Results are only logged / written to ``writer``.
    """
    eval_data_iterator = self.load_dataset(mesh, is_training=False)
    eval_rng = eval_rng_key
    # Only process 0 accumulates and reports; query the index once.
    is_lead_process = jax.process_index() == 0
    eval_losses_by_timestep = {}

    # NOTE(review): log text below is kept exactly as it appears in the
    # reviewed source, including spacing around placeholders -- confirm
    # against the real file.
    while True:
        eval_start_time = datetime.datetime.now()
        # Guard only the fetch: a StopIteration here means the iterator is
        # exhausted. Anything raised by the eval step itself must propagate
        # instead of being mistaken for end-of-data.
        try:
            eval_batch = load_next_batch(eval_data_iterator, None, self.config)
        except StopIteration:
            break

        with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
            metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
            metrics["scalar"]["learning/eval_loss"].block_until_ready()

        losses = metrics["scalar"]["learning/eval_loss"]
        timesteps = eval_batch["timesteps"]
        # The gathers are executed by every process (not just process 0).
        gathered_losses = jax.device_get(multihost_utils.process_allgather(losses))
        gathered_timesteps = jax.device_get(multihost_utils.process_allgather(timesteps))

        if is_lead_process:
            for t, loss in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
                eval_losses_by_timestep.setdefault(int(t), []).append(loss)

        eval_duration = datetime.datetime.now() - eval_start_time
        max_logging.log(f"Eval time: {eval_duration.total_seconds():.2f} seconds.")

    # Nothing to report if no batch was evaluated or this is not process 0.
    if not (is_lead_process and eval_losses_by_timestep):
        return

    max_logging.log(f"Step {step} , calculating mean loss per timestep...")
    bucket_cap = self.config.eval_max_number_of_samples_in_bucket
    mean_per_timestep = []
    for timestep, bucket in sorted(eval_losses_by_timestep.items()):
        # Slicing already clamps to the bucket length, so no explicit min()
        # is needed; truncate before building the array to skip unused samples.
        mean_loss = jnp.mean(jnp.array(bucket[:bucket_cap]))
        max_logging.log(f" Mean eval loss for timestep {timestep} : {mean_loss:.4f} ")
        mean_per_timestep.append(mean_loss)

    # Unweighted mean over timesteps: every bucket counts equally regardless
    # of how many samples it holds.
    final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
    max_logging.log(f"Step {step} , Final Average Eval loss: {final_eval_loss:.4f} ")
    if writer:
        writer.add_scalar("learning/eval_loss", final_eval_loss, step)
235285
236286 def training_loop (self , pipeline , optimizer , learning_rate_scheduler , train_data_iterator ):
237287 mesh = pipeline .mesh
@@ -321,52 +371,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
321371 inference_generate_video (self .config , pipeline , filename_prefix = f"{ step + 1 } -train_steps-" )
322372 # Re-create the iterator each time you start evaluation to reset it
323373 # This assumes your data loading logic can be called to get a fresh iterator.
324- eval_data_iterator = self .load_dataset (mesh , is_training = False )
325- eval_rng = eval_rng_key
326- eval_losses_by_timestep = {}
327- # Loop indefinitely until the iterator is exhausted
328- while True :
329- try :
330- eval_start_time = datetime .datetime .now ()
331- eval_batch = load_next_batch (eval_data_iterator , None , self .config )
332- with pipeline .mesh , nn_partitioning .axis_rules (
333- self .config .logical_axis_rules
334- ):
335- metrics , eval_rng = p_eval_step (state , eval_batch , eval_rng , scheduler_state )
336- metrics ["scalar" ]["learning/eval_loss" ].block_until_ready ()
337- losses = metrics ["scalar" ]["learning/eval_loss" ]
338- timesteps = eval_batch ["timesteps" ]
339- gathered_losses = multihost_utils .process_allgather (losses )
340- gathered_losses = jax .device_get (gathered_losses )
341- gathered_timesteps = multihost_utils .process_allgather (timesteps )
342- gathered_timesteps = jax .device_get (gathered_timesteps )
343- if jax .process_index () == 0 :
344- for t , l in zip (gathered_timesteps .flatten (), gathered_losses .flatten ()):
345- timestep = int (t )
346- if timestep not in eval_losses_by_timestep :
347- eval_losses_by_timestep [timestep ] = []
348- eval_losses_by_timestep [timestep ].append (l )
349- eval_end_time = datetime .datetime .now ()
350- eval_duration = eval_end_time - eval_start_time
351- max_logging .log (f" Eval step time { eval_duration .total_seconds ():.2f} seconds." )
352- except StopIteration :
353- # This block is executed when the iterator has no more data
354- break
355- # Check if any evaluation was actually performed
356- if eval_losses_by_timestep and jax .process_index () == 0 :
357- mean_per_timestep = []
358- if jax .process_index () == 0 :
359- max_logging .log (f"Step { step } , calculating mean loss per timestep..." )
360- for timestep , losses in sorted (eval_losses_by_timestep .items ()):
361- losses = jnp .array (losses )
362- losses = losses [: min (self .config .eval_max_number_of_samples_in_bucket , len (losses ))]
363- mean_loss = jnp .mean (losses )
364- max_logging .log (f" Mean eval loss for timestep { timestep } : { mean_loss :.4f} " )
365- mean_per_timestep .append (mean_loss )
366- final_eval_loss = jnp .mean (jnp .array (mean_per_timestep ))
367- max_logging .log (f"Step { step } , Final Average Eval loss: { final_eval_loss :.4f} " )
368- if writer :
369- writer .add_scalar ("learning/eval_loss" , final_eval_loss , step )
374+ self .eval (mesh , eval_rng_key , step , p_eval_step , state , scheduler_state , writer )
375+
370376 example_batch = next_batch_future .result ()
371377 if step != 0 and self .config .checkpoint_every != - 1 and step % self .config .checkpoint_every == 0 :
372378 max_logging .log (f"Saving checkpoint for step { step } " )
0 commit comments