Skip to content

Commit 9e74521

Browse files
committed
add eval time
1 parent 9b4ae33 commit 9e74521

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
328328
while True:
329329
try:
330330
with mesh:
331+
eval_start_time = datetime.datetime.now()
331332
eval_batch = load_next_batch(eval_data_iterator, None, self.config)
332333
metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
333334
losses = metrics["scalar"]["learning/eval_loss"]
@@ -336,11 +337,15 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
336337
gathered_timesteps = jax.device_get(gathered_timesteps_on_device)
337338
gathered_losses_on_device = multihost_utils.process_allgather(losses)
338339
gathered_losses = jax.device_get(gathered_losses_on_device)
339-
for t, l in zip(gathered_timesteps, gathered_losses):
340+
for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
340341
timestep = int(t)
341342
if timestep not in eval_losses_by_timestep:
342343
eval_losses_by_timestep[timestep] = []
343344
eval_losses_by_timestep[timestep].append(l)
345+
eval_end_time = datetime.datetime.now()
346+
eval_duration = eval_end_time - eval_start_time
347+
if jax.process_index() == 0:
348+
max_logging.log(f" Eval step time {eval_duration.total_seconds():.2f} seconds.")
344349
except StopIteration:
345350
# This block is executed when the iterator has no more data
346351
break

0 commit comments

Comments
 (0)