
Commit c0fa2ca

refactor
1 parent ad8f9ba commit c0fa2ca

2 files changed: 58 additions & 50 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 0 deletions
@@ -319,3 +319,5 @@ eval_every: -1
 eval_data_dir: ""
 enable_generate_video_for_eval: False # This will increase the used TPU memory.
 eval_max_number_of_samples_in_bucket: 60
+
+enable_ssim: True

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 56 additions & 50 deletions
@@ -211,8 +211,9 @@ def prepare_sample_eval(features):
   def start_training(self):

     pipeline = self.load_checkpoint()
-    # Generate a sample before training to compare against generated sample after training.
-    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+    if self.config.enable_ssim:
+      # Generate a sample before training to compare against generated sample after training.
+      pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")

     if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
       # save some memory.
@@ -230,8 +231,57 @@ def start_training(self):
     # Returns pipeline with trained transformer state
     pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)

-    posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
-    print_ssim(pretrained_video_path, posttrained_video_path)
+    if self.config.enable_ssim:
+      posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
+      print_ssim(pretrained_video_path, posttrained_video_path)
+
+  def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer):
+    eval_data_iterator = self.load_dataset(mesh, is_training=False)
+    eval_rng = eval_rng_key
+    eval_losses_by_timestep = {}
+    # Loop until the iterator is exhausted
+    while True:
+      try:
+        eval_start_time = datetime.datetime.now()
+        eval_batch = load_next_batch(eval_data_iterator, None, self.config)
+        with mesh, nn_partitioning.axis_rules(
+            self.config.logical_axis_rules
+        ):
+          metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
+          metrics["scalar"]["learning/eval_loss"].block_until_ready()
+        losses = metrics["scalar"]["learning/eval_loss"]
+        timesteps = eval_batch["timesteps"]
+        gathered_losses = multihost_utils.process_allgather(losses)
+        gathered_losses = jax.device_get(gathered_losses)
+        gathered_timesteps = multihost_utils.process_allgather(timesteps)
+        gathered_timesteps = jax.device_get(gathered_timesteps)
+        if jax.process_index() == 0:
+          for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
+            timestep = int(t)
+            if timestep not in eval_losses_by_timestep:
+              eval_losses_by_timestep[timestep] = []
+            eval_losses_by_timestep[timestep].append(l)
+        eval_end_time = datetime.datetime.now()
+        eval_duration = eval_end_time - eval_start_time
+        max_logging.log(f"Eval time: {eval_duration.total_seconds():.2f} seconds.")
+      except StopIteration:
+        # The iterator has no more data
+        break
+    # Check if any evaluation was actually performed
+    if eval_losses_by_timestep and jax.process_index() == 0:
+      mean_per_timestep = []
+      if jax.process_index() == 0:
+        max_logging.log(f"Step {step}, calculating mean loss per timestep...")
+        for timestep, losses in sorted(eval_losses_by_timestep.items()):
+          losses = jnp.array(losses)
+          losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
+          mean_loss = jnp.mean(losses)
+          max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
+          mean_per_timestep.append(mean_loss)
+        final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
+        max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
+        if writer:
+          writer.add_scalar("learning/eval_loss", final_eval_loss, step)

   def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
     mesh = pipeline.mesh
@@ -321,52 +371,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
         inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
         # Re-create the iterator each time you start evaluation to reset it
         # This assumes your data loading logic can be called to get a fresh iterator.
-        eval_data_iterator = self.load_dataset(mesh, is_training=False)
-        eval_rng = eval_rng_key
-        eval_losses_by_timestep = {}
-        # Loop indefinitely until the iterator is exhausted
-        while True:
-          try:
-            eval_start_time = datetime.datetime.now()
-            eval_batch = load_next_batch(eval_data_iterator, None, self.config)
-            with pipeline.mesh, nn_partitioning.axis_rules(
-                self.config.logical_axis_rules
-            ):
-              metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
-              metrics["scalar"]["learning/eval_loss"].block_until_ready()
-            losses = metrics["scalar"]["learning/eval_loss"]
-            timesteps = eval_batch["timesteps"]
-            gathered_losses = multihost_utils.process_allgather(losses)
-            gathered_losses = jax.device_get(gathered_losses)
-            gathered_timesteps = multihost_utils.process_allgather(timesteps)
-            gathered_timesteps = jax.device_get(gathered_timesteps)
-            if jax.process_index() == 0:
-              for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
-                timestep = int(t)
-                if timestep not in eval_losses_by_timestep:
-                  eval_losses_by_timestep[timestep] = []
-                eval_losses_by_timestep[timestep].append(l)
-            eval_end_time = datetime.datetime.now()
-            eval_duration = eval_end_time - eval_start_time
-            max_logging.log(f" Eval step time {eval_duration.total_seconds():.2f} seconds.")
-          except StopIteration:
-            # This block is executed when the iterator has no more data
-            break
-        # Check if any evaluation was actually performed
-        if eval_losses_by_timestep and jax.process_index() == 0:
-          mean_per_timestep = []
-          if jax.process_index() == 0:
-            max_logging.log(f"Step {step}, calculating mean loss per timestep...")
-            for timestep, losses in sorted(eval_losses_by_timestep.items()):
-              losses = jnp.array(losses)
-              losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
-              mean_loss = jnp.mean(losses)
-              max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
-              mean_per_timestep.append(mean_loss)
-            final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
-            max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
-            if writer:
-              writer.add_scalar("learning/eval_loss", final_eval_loss, step)
+        self.eval(mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer)
+
       example_batch = next_batch_future.result()
       if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
         max_logging.log(f"Saving checkpoint for step {step}")

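print_ssim itself is untouched by this commit and its body is not shown in the diff. As a rough sketch of what an SSIM comparison between the pre- and post-training sample videos could look like (an assumption for illustration, not the repository's actual implementation; it presumes imageio-decodable videos with uint8 frames of equal shape):

import imageio.v3 as iio
import numpy as np
from skimage.metrics import structural_similarity

def mean_video_ssim(video_path_a, video_path_b):
  # Hypothetical stand-in for print_ssim's metric: decode both videos to
  # (num_frames, H, W, C) uint8 arrays and average per-frame SSIM.
  frames_a = iio.imread(video_path_a)
  frames_b = iio.imread(video_path_b)
  num_frames = min(len(frames_a), len(frames_b))
  scores = [
      structural_similarity(fa, fb, channel_axis=-1, data_range=255)
      for fa, fb in zip(frames_a[:num_frames], frames_b[:num_frames])
  ]
  return float(np.mean(scores))

An SSIM near 1.0 would indicate the fine-tuned model still produces output close to the pre-training sample for the same prompt, which is the comparison enable_ssim gates.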