Skip to content

Commit 8e2bddb

Browse files
committed
block until ready
1 parent 140db99 commit 8e2bddb

1 file changed

Lines changed: 1 addition & 2 deletions

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,10 +333,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
333333
self.config.logical_axis_rules
334334
):
335335
metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
336+
metrics["scalar"]["learning/eval_loss"].block_until_ready()
336337
losses = metrics["scalar"]["learning/eval_loss"]
337338
timesteps = eval_batch["timesteps"]
338-
gathered_losses_on_device = multihost_utils.process_allgather(losses)
339-
gathered_losses = jax.device_get(gathered_losses_on_device)
340339
for t, l in zip(timesteps.flatten(), losses.flatten()):
341340
timestep = int(t)
342341
if timestep not in eval_losses_by_timestep:

0 commit comments

Comments
 (0)