Skip to content

Commit aaaa094

Browse files
committed
Fix for-loop bugs on timesteps and losses
1 parent a113271 commit aaaa094

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from skimage.metrics import structural_similarity as ssim
3939
from flax.training import train_state
4040
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
41+
from jax.experimental import multihost_utils
4142

4243

4344
class TrainState(train_state.TrainState):
@@ -331,7 +332,11 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
331332
metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
332333
losses = metrics["scalar"]["learning/eval_loss"]
333334
timesteps = eval_batch["timesteps"]
334-
for t, l in zip(timesteps, losses):
335+
gathered_timesteps_on_device = multihost_utils.process_allgather(timesteps)
336+
gathered_timesteps = jax.device_get(gathered_timesteps_on_device)
337+
gathered_losses_on_device = multihost_utils.process_allgather(losses)
338+
gathered_losses = jax.device_get(gathered_losses_on_device)
339+
for t, l in zip(gathered_timesteps, gathered_losses):
335340
timestep = int(t)
336341
if timestep not in eval_losses_by_timestep:
337342
eval_losses_by_timestep[timestep] = []

0 commit comments

Comments (0)