We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e6f495e, commit c3f53b1. Copy full SHA for c3f53b1
1 file changed
src/maxdiffusion/trainers/wan_trainer.py
@@ -465,6 +465,7 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
465
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
466
loss = (training_target - model_pred) ** 2
467
loss = loss * training_weight
468
+ # Calculate the mean loss per sample across all non-batch dimensions.
469
loss = loss.reshape(loss.shape[0], -1).mean(axis=1)
470
471
return loss
0 commit comments