@@ -328,6 +328,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
328328 while True :
329329 try :
330330 with mesh :
331+ eval_start_time = datetime .datetime .now ()
331332 eval_batch = load_next_batch (eval_data_iterator , None , self .config )
332333 metrics , eval_rng = p_eval_step (state , eval_batch , eval_rng , scheduler_state )
333334 losses = metrics ["scalar" ]["learning/eval_loss" ]
@@ -336,11 +337,15 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
336337 gathered_timesteps = jax .device_get (gathered_timesteps_on_device )
337338 gathered_losses_on_device = multihost_utils .process_allgather (losses )
338339 gathered_losses = jax .device_get (gathered_losses_on_device )
339- for t , l in zip (gathered_timesteps , gathered_losses ):
340+ for t , l in zip (gathered_timesteps . flatten () , gathered_losses . flatten () ):
340341 timestep = int (t )
341342 if timestep not in eval_losses_by_timestep :
342343 eval_losses_by_timestep [timestep ] = []
343344 eval_losses_by_timestep [timestep ].append (l )
345+ eval_end_time = datetime .datetime .now ()
346+ eval_duration = eval_end_time - eval_start_time
347+ if jax .process_index () == 0 :
348+ max_logging .log (f" Eval step time { eval_duration .total_seconds ():.2f} seconds." )
344349 except StopIteration :
345350 # This block is executed when the iterator has no more data
346351 break
0 commit comments