@@ -260,6 +260,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
260260 )
261261
262262 rng = jax .random .key (self .config .seed )
263+ rng , eval_rng_key = jax .random .split (rng )
263264 start_step = 0
264265 last_step_completion = datetime .datetime .now ()
265266 local_metrics_file = open (self .config .metrics_file , "a" , encoding = "utf8" ) if self .config .metrics_file else None
@@ -305,7 +306,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
305306 # Re-create the iterator each time you start evaluation to reset it
306307 # This assumes your data loading logic can be called to get a fresh iterator.
307308 eval_data_iterator = self .load_dataset (mesh , is_training = False )
308- eval_rng = jax . random . key ( self . config . seed + step )
309+ eval_rng = eval_rng_key
309310 eval_metrics = []
310311 # Loop indefinitely until the iterator is exhausted
311312 while True :
@@ -394,7 +395,8 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
394395 """
395396 Computes the evaluation loss for a single batch without updating model weights.
396397 """
397- _ , new_rng , timestep_rng = jax .random .split (rng , num = 3 )
398+ # These values are fixed for the evaluation dataset, as the initial RNG for each evaluation is the same
399+ noise_rng , timestep_rng , new_rng = jax .random .split (rng , num = 3 )
398400
399401 # This ensures the batch size is consistent, though it might be redundant
400402 # if the evaluation dataloader is already configured correctly.
@@ -419,14 +421,15 @@ def loss_fn(params):
419421 0 ,
420422 scheduler .config .num_train_timesteps ,
421423 )
422- noise = jax .random .normal (key = new_rng , shape = latents .shape , dtype = latents .dtype )
424+ noise = jax .random .normal (key = noise_rng , shape = latents .shape , dtype = latents .dtype )
423425 noisy_latents = scheduler .add_noise (scheduler_state , latents , noise , timesteps )
424426
425427 # Get the model's prediction
426428 model_pred = model (
427429 hidden_states = noisy_latents ,
428430 timestep = timesteps ,
429431 encoder_hidden_states = encoder_hidden_states ,
432+ deterministic = True ,
430433 )
431434
432435 # Calculate the loss against the target
0 commit comments