Skip to content

Commit c3f53b1

Browse files
committed
solve comment
1 parent e6f495e commit c3f53b1

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
465465
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
466466
loss = (training_target - model_pred) ** 2
467467
loss = loss * training_weight
468+
# Calculate the mean loss per sample across all non-batch dimensions.
468469
loss = loss.reshape(loss.shape[0], -1).mean(axis=1)
469470

470471
return loss

0 commit comments

Comments
 (0)