Skip to content

Commit 140db99

Browse files
committed
fix eval slow bug
1 parent 9e74521 commit 140db99

1 file changed

Lines changed: 16 additions & 16 deletions

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -327,25 +327,25 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
327327
# Loop indefinitely until the iterator is exhausted
328328
while True:
329329
try:
330-
with mesh:
331-
eval_start_time = datetime.datetime.now()
332-
eval_batch = load_next_batch(eval_data_iterator, None, self.config)
330+
eval_start_time = datetime.datetime.now()
331+
eval_batch = load_next_batch(eval_data_iterator, None, self.config)
332+
with pipeline.mesh, nn_partitioning.axis_rules(
333+
self.config.logical_axis_rules
334+
):
333335
metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
334-
losses = metrics["scalar"]["learning/eval_loss"]
335-
timesteps = eval_batch["timesteps"]
336-
gathered_timesteps_on_device = multihost_utils.process_allgather(timesteps)
337-
gathered_timesteps = jax.device_get(gathered_timesteps_on_device)
338-
gathered_losses_on_device = multihost_utils.process_allgather(losses)
339-
gathered_losses = jax.device_get(gathered_losses_on_device)
340-
for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
341-
timestep = int(t)
342-
if timestep not in eval_losses_by_timestep:
343-
eval_losses_by_timestep[timestep] = []
344-
eval_losses_by_timestep[timestep].append(l)
336+
losses = metrics["scalar"]["learning/eval_loss"]
337+
timesteps = eval_batch["timesteps"]
338+
gathered_losses_on_device = multihost_utils.process_allgather(losses)
339+
gathered_losses = jax.device_get(gathered_losses_on_device)
340+
for t, l in zip(timesteps.flatten(), losses.flatten()):
341+
timestep = int(t)
342+
if timestep not in eval_losses_by_timestep:
343+
eval_losses_by_timestep[timestep] = []
344+
eval_losses_by_timestep[timestep].append(l)
345+
if jax.process_index() == 0:
345346
eval_end_time = datetime.datetime.now()
346347
eval_duration = eval_end_time - eval_start_time
347-
if jax.process_index() == 0:
348-
max_logging.log(f" Eval step time {eval_duration.total_seconds():.2f} seconds.")
348+
max_logging.log(f" Eval step time {eval_duration.total_seconds():.2f} seconds.")
349349
except StopIteration:
350350
# This block is executed when the iterator has no more data
351351
break

0 commit comments

Comments
 (0)