Commit 82502da

committed
version 2
1 parent 8b1b427 commit 82502da

1 file changed: src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 9 additions & 14 deletions
```diff
@@ -334,12 +334,14 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
           eval_batch["timesteps"], eval_data_shardings["timesteps"]
       )
       metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
-      loss = metrics["scalar"]["learning/eval_loss"]
-      timestep = int(eval_batch["timesteps"][0])
-      jax.debug.print("timesteps in eval_step: {x}", x=timestep)
-      if timestep not in eval_losses_by_timestep:
-        eval_losses_by_timestep[timestep] = []
-      eval_losses_by_timestep[timestep].append(loss)
+      losses = metrics["scalar"]["learning/eval_loss"]
+      timesteps = eval_batch["timesteps"]
+      for t, l in zip(timesteps, losses):
+        timestep = int(t)
+        if timestep not in eval_losses_by_timestep:
+          eval_losses_by_timestep[timestep] = []
+        eval_losses_by_timestep[timestep].append(l)
+        print(f"timesteps: {timestep}, losses: {l}")
     except StopIteration:
       # This block is executed when the iterator has no more data
       break
```
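The new loop treats `metrics["scalar"]["learning/eval_loss"]` as a per-sample vector and buckets each sample's loss under its integer timestep, instead of recording one scalar under the batch's first timestep. A minimal, self-contained sketch of that bookkeeping, assuming `timesteps` and `losses` are aligned 1-D arrays (the batch values are made up, and `defaultdict` stands in for the diff's explicit membership check):

```python
from collections import defaultdict

import jax.numpy as jnp

eval_losses_by_timestep = defaultdict(list)

def accumulate(timesteps, losses):
  # One bucket per integer diffusion timestep; each entry is one sample's loss.
  for t, l in zip(timesteps, losses):
    eval_losses_by_timestep[int(t)].append(float(l))

# Made-up batch: four samples over three distinct timesteps.
accumulate(jnp.array([10, 10, 500, 990]), jnp.array([0.12, 0.15, 0.40, 0.72]))
mean_by_t = {t: sum(v) / len(v) for t, v in eval_losses_by_timestep.items()}
print(mean_by_t)  # e.g. {10: 0.135, 500: 0.40, 990: 0.72}
```

Note the hunk also swaps `jax.debug.print` for a plain `print`; that is fine here because this loop runs on the host with concrete arrays, not inside a traced/jitted function.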
```diff
@@ -433,13 +435,6 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
   # This ensures the batch size is consistent, though it might be redundant
   # if the evaluation dataloader is already configured correctly.
   jax.debug.print("timesteps before clip: {x}", x=data["timesteps"])
-  for k, v in data.items():
-    if k != "timesteps":
-      data[k] = v[: config.global_batch_size_to_train_on, :]
-    else:
-      data[k] = v[: config.global_batch_size_to_train_on]
-  jax.debug.print("timesteps after clip: {x}", x=data["timesteps"])
-
   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
   def loss_fn(params):
```
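For reference, the deleted block sliced every entry of the eval batch down to `config.global_batch_size_to_train_on` along the leading axis. A standalone sketch of that behavior with made-up shapes (`latents` is an assumed key; the real batch holds the model's actual inputs):

```python
import jax.numpy as jnp

global_batch_size_to_train_on = 2  # illustrative stand-in for the config value
data = {
    "latents": jnp.zeros((4, 8)),
    "timesteps": jnp.array([10, 500, 990, 990]),
}
for k, v in data.items():
  if k != "timesteps":
    data[k] = v[:global_batch_size_to_train_on, :]  # 2-D slice for stacked tensors
  else:
    data[k] = v[:global_batch_size_to_train_on]     # 1-D slice for per-sample timesteps
# data["latents"].shape == (2, 8); data["timesteps"] == [10, 500]
```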
```diff
@@ -467,7 +462,7 @@ def loss_fn(params):
     training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
     loss = (training_target - model_pred) ** 2
     loss = loss * training_weight
-    loss = jnp.mean(loss)
+    loss = loss.reshape(loss.shape[0], -1).mean(axis=1)

     return loss

```