Skip to content

Commit b6fa61e

Browse files
committed
keep eval the same for each evaluation
1 parent 95afb77 commit b6fa61e

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
260260
)
261261

262262
rng = jax.random.key(self.config.seed)
263+
rng, eval_rng_key = jax.random.split(rng)
263264
start_step = 0
264265
last_step_completion = datetime.datetime.now()
265266
local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
@@ -305,7 +306,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
305306
# Re-create the iterator each time you start evaluation to reset it
306307
# This assumes your data loading logic can be called to get a fresh iterator.
307308
eval_data_iterator = self.load_dataset(mesh, is_training=False)
308-
eval_rng = jax.random.key(self.config.seed + step)
309+
eval_rng = eval_rng_key
309310
eval_metrics = []
310311
# Loop indefinitely until the iterator is exhausted
311312
while True:
@@ -394,7 +395,8 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
394395
"""
395396
Computes the evaluation loss for a single batch without updating model weights.
396397
"""
397-
_, new_rng, timestep_rng = jax.random.split(rng, num=3)
398+
# These values are fixed for the evaluation dataset because the initial rng for each evaluation is the same
399+
noise_rng, timestep_rng, new_rng = jax.random.split(rng, num=3)
398400

399401
# This ensures the batch size is consistent, though it might be redundant
400402
# if the evaluation dataloader is already configured correctly.
@@ -419,14 +421,15 @@ def loss_fn(params):
419421
0,
420422
scheduler.config.num_train_timesteps,
421423
)
422-
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
424+
noise = jax.random.normal(key=noise_rng, shape=latents.shape, dtype=latents.dtype)
423425
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
424426

425427
# Get the model's prediction
426428
model_pred = model(
427429
hidden_states=noisy_latents,
428430
timestep=timesteps,
429431
encoder_hidden_states=encoder_hidden_states,
432+
deterministic=True,
430433
)
431434

432435
# Calculate the loss against the target

0 commit comments

Comments
 (0)