@@ -327,25 +327,25 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
327327 # Loop indefinitely until the iterator is exhausted
328328 while True :
329329 try :
330- with mesh :
331- eval_start_time = datetime .datetime .now ()
332- eval_batch = load_next_batch (eval_data_iterator , None , self .config )
330+ eval_start_time = datetime .datetime .now ()
331+ eval_batch = load_next_batch (eval_data_iterator , None , self .config )
332+ with pipeline .mesh , nn_partitioning .axis_rules (
333+ self .config .logical_axis_rules
334+ ):
333335 metrics , eval_rng = p_eval_step (state , eval_batch , eval_rng , scheduler_state )
334- losses = metrics ["scalar" ]["learning/eval_loss" ]
335- timesteps = eval_batch ["timesteps" ]
336- gathered_timesteps_on_device = multihost_utils .process_allgather (timesteps )
337- gathered_timesteps = jax .device_get (gathered_timesteps_on_device )
338- gathered_losses_on_device = multihost_utils .process_allgather (losses )
339- gathered_losses = jax .device_get (gathered_losses_on_device )
340- for t , l in zip (gathered_timesteps .flatten (), gathered_losses .flatten ()):
341- timestep = int (t )
342- if timestep not in eval_losses_by_timestep :
343- eval_losses_by_timestep [timestep ] = []
344- eval_losses_by_timestep [timestep ].append (l )
336+ losses = metrics ["scalar" ]["learning/eval_loss" ]
337+ timesteps = eval_batch ["timesteps" ]
338+ gathered_losses_on_device = multihost_utils .process_allgather (losses )
339+ gathered_losses = jax .device_get (gathered_losses_on_device )
340+ for t , l in zip (timesteps .flatten (), losses .flatten ()):
341+ timestep = int (t )
342+ if timestep not in eval_losses_by_timestep :
343+ eval_losses_by_timestep [timestep ] = []
344+ eval_losses_by_timestep [timestep ].append (l )
345+ if jax .process_index () == 0 :
345346 eval_end_time = datetime .datetime .now ()
346347 eval_duration = eval_end_time - eval_start_time
347- if jax .process_index () == 0 :
348- max_logging .log (f" Eval step time { eval_duration .total_seconds ():.2f} seconds." )
348+ max_logging .log (f" Eval step time { eval_duration .total_seconds ():.2f} seconds." )
349349 except StopIteration :
350350 # This block is executed when the iterator has no more data
351351 break
0 commit comments