Skip to content

Commit fb1c00b

Browse files
authored
fix eval on g3 (#266)
1 parent 158e1f2 commit fb1c00b

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,9 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr
250250
metrics["scalar"]["learning/eval_loss"].block_until_ready()
251251
losses = metrics["scalar"]["learning/eval_loss"]
252252
timesteps = eval_batch["timesteps"]
253-
gathered_losses = multihost_utils.process_allgather(losses)
253+
gathered_losses = multihost_utils.process_allgather(losses, tiled=True)
254254
gathered_losses = jax.device_get(gathered_losses)
255-
gathered_timesteps = multihost_utils.process_allgather(timesteps)
255+
gathered_timesteps = multihost_utils.process_allgather(timesteps, tiled=True)
256256
gathered_timesteps = jax.device_get(gathered_timesteps)
257257
if jax.process_index() == 0:
258258
for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):

0 commit comments

Comments
 (0)