diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 5c08a4060..1265223a7 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -250,9 +250,9 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr
     metrics["scalar"]["learning/eval_loss"].block_until_ready()
     losses = metrics["scalar"]["learning/eval_loss"]
     timesteps = eval_batch["timesteps"]
-    gathered_losses = multihost_utils.process_allgather(losses)
+    gathered_losses = multihost_utils.process_allgather(losses, tiled=True)
     gathered_losses = jax.device_get(gathered_losses)
-    gathered_timesteps = multihost_utils.process_allgather(timesteps)
+    gathered_timesteps = multihost_utils.process_allgather(timesteps, tiled=True)
     gathered_timesteps = jax.device_get(gathered_timesteps)
     if jax.process_index() == 0:
       for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):