Skip to content

Commit ad8f9ba

Browse files
committed
successfully run on multi-host
1 parent 8e2bddb commit ad8f9ba

2 files changed

Lines changed: 13 additions & 14 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,4 +319,3 @@ eval_every: -1
319319
eval_data_dir: ""
320320
enable_generate_video_for_eval: False # This will increase the used TPU memory.
321321
eval_max_number_of_samples_in_bucket: 60
322-
eval_max_processed_batch_size: 8 # This is the max batch size per device for eval step. If the global eval batch size is larger than this, the eval step will be run multiple times.

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -336,37 +336,37 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
336336
metrics["scalar"]["learning/eval_loss"].block_until_ready()
337337
losses = metrics["scalar"]["learning/eval_loss"]
338338
timesteps = eval_batch["timesteps"]
339-
for t, l in zip(timesteps.flatten(), losses.flatten()):
340-
timestep = int(t)
341-
if timestep not in eval_losses_by_timestep:
342-
eval_losses_by_timestep[timestep] = []
343-
eval_losses_by_timestep[timestep].append(l)
339+
gathered_losses = multihost_utils.process_allgather(losses)
340+
gathered_losses = jax.device_get(gathered_losses)
341+
gathered_timesteps = multihost_utils.process_allgather(timesteps)
342+
gathered_timesteps = jax.device_get(gathered_timesteps)
344343
if jax.process_index() == 0:
344+
for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
345+
timestep = int(t)
346+
if timestep not in eval_losses_by_timestep:
347+
eval_losses_by_timestep[timestep] = []
348+
eval_losses_by_timestep[timestep].append(l)
345349
eval_end_time = datetime.datetime.now()
346350
eval_duration = eval_end_time - eval_start_time
347351
max_logging.log(f" Eval step time {eval_duration.total_seconds():.2f} seconds.")
348352
except StopIteration:
349353
# This block is executed when the iterator has no more data
350354
break
351355
# Check if any evaluation was actually performed
352-
if eval_losses_by_timestep:
356+
if eval_losses_by_timestep and jax.process_index() == 0:
353357
mean_per_timestep = []
354358
if jax.process_index() == 0:
355359
max_logging.log(f"Step {step}, calculating mean loss per timestep...")
356360
for timestep, losses in sorted(eval_losses_by_timestep.items()):
357361
losses = jnp.array(losses)
358362
losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
359363
mean_loss = jnp.mean(losses)
360-
if jax.process_index() == 0:
361-
max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}")
364+
max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
362365
mean_per_timestep.append(mean_loss)
363366
final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
364-
if jax.process_index() == 0:
365-
max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
367+
max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
366368
if writer:
367369
writer.add_scalar("learning/eval_loss", final_eval_loss, step)
368-
else:
369-
max_logging.log(f"Step {step}, evaluation dataset was empty.")
370370
example_batch = next_batch_future.result()
371371
if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
372372
max_logging.log(f"Saving checkpoint for step {step}")
@@ -468,7 +468,7 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
468468
# Directly compute the loss without calculating gradients.
469469
# The model's state.params are used but not updated.
470470
bs = len(data["latents"])
471-
single_batch_size = min(config.eval_max_processed_batch_size, config.global_batch_size_to_train_on)
471+
single_batch_size = config.global_batch_size_to_train_on
472472
losses = jnp.zeros(bs)
473473
for i in range(0, bs, single_batch_size):
474474
start = i

0 commit comments

Comments (0)