@@ -334,12 +334,14 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
334334 eval_batch ["timesteps" ], eval_data_shardings ["timesteps" ]
335335 )
336336 metrics , eval_rng = p_eval_step (state , eval_batch , eval_rng , scheduler_state )
337- loss = metrics ["scalar" ]["learning/eval_loss" ]
338- timestep = int (eval_batch ["timesteps" ][0 ])
339- jax .debug .print ("timesteps in eval_step: {x}" , x = timestep )
340- if timestep not in eval_losses_by_timestep :
341- eval_losses_by_timestep [timestep ] = []
342- eval_losses_by_timestep [timestep ].append (loss )
337+ losses = metrics ["scalar" ]["learning/eval_loss" ]
338+ timesteps = eval_batch ["timesteps" ]
339+ for t , l in zip (timesteps , losses ):
340+ timestep = int (t )
341+ if timestep not in eval_losses_by_timestep :
342+ eval_losses_by_timestep [timestep ] = []
343+ eval_losses_by_timestep [timestep ].append (l )
344+ print (f"timesteps: { timestep } , losses: { l } " )
343345 except StopIteration :
344346 # This block is executed when the iterator has no more data
345347 break
@@ -433,13 +435,6 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
433435 # This ensures the batch size is consistent, though it might be redundant
434436 # if the evaluation dataloader is already configured correctly.
435437 jax .debug .print ("timesteps before clip: {x}" , x = data ["timesteps" ])
436- for k , v in data .items ():
437- if k != "timesteps" :
438- data [k ] = v [: config .global_batch_size_to_train_on , :]
439- else :
440- data [k ] = v [: config .global_batch_size_to_train_on ]
441- jax .debug .print ("timesteps after clip: {x}" , x = data ["timesteps" ])
442-
443438 # The loss function logic is identical to training. We are evaluating the model's
444439 # ability to perform its core training objective (e.g., denoising).
445440 def loss_fn (params ):
@@ -467,7 +462,7 @@ def loss_fn(params):
467462 training_weight = jnp .expand_dims (scheduler .training_weight (scheduler_state , timesteps ), axis = (1 , 2 , 3 , 4 ))
468463 loss = (training_target - model_pred ) ** 2
469464 loss = loss * training_weight
470- loss = jnp . mean (loss )
465+ loss = loss . reshape (loss . shape [ 0 ], - 1 ). mean ( axis = 1 )
471466
472467 return loss
473468
0 commit comments