@@ -342,15 +342,18 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     # Check if any evaluation was actually performed
     if eval_losses_by_timestep:
       mean_per_timestep = []
-      max_logging.log(f"Step {step}, calculating mean loss per timestep...")
+      if jax.process_index() == 0:
+        max_logging.log(f"Step {step}, calculating mean loss per timestep...")
       for timestep, losses in sorted(eval_losses_by_timestep.items()):
         losses = jnp.array(losses)
         losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
         mean_loss = jnp.mean(losses)
-        max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}")
+        if jax.process_index() == 0:
+          max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}")
         mean_per_timestep.append(mean_loss)
       final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
-      max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
+      if jax.process_index() == 0:
+        max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
       if writer:
         writer.add_scalar("learning/eval_loss", final_eval_loss, step)
     else:
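The `jax.process_index() == 0` guard now repeats before every log call. One way to keep the call sites compact is a small wrapper; this is a minimal sketch, not part of the diff, and the `log_host0` name (and the assumption that `max_logging` is importable the same way as in the surrounding module) is hypothetical:

import jax
import max_logging  # assumed importable as in the surrounding module

def log_host0(msg):
  # Emit only from the coordinator process, so multi-host runs don't
  # print one copy of every message per host.
  if jax.process_index() == 0:
    max_logging.log(msg)

With such a helper, each guarded pair above collapses back to a single line, e.g. `log_host0(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")`.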
@@ -428,14 +431,14 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):

   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
-  def loss_fn(params):
+  def loss_fn(params, latents, encoder_hidden_states, timesteps):
     # Reconstruct the model from its definition and parameters
     model = nnx.merge(state.graphdef, params, state.rest_of_state)

     # Prepare inputs
-    latents = data["latents"].astype(config.weights_dtype)
-    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
-    timesteps = data["timesteps"].astype("int64")
+    # latents = data["latents"].astype(config.weights_dtype)
+    # encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
+    # timesteps = data["timesteps"].astype("int64")

     noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
     noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
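Note that for the accumulation in the next hunk (`losses.at[start:end].set(loss)`) to be shape-correct, `loss_fn` must return one loss value per sample rather than a scalar mean. The diff does not show that reduction, but a per-sample MSE denoising loss would look roughly like this sketch, where `model_pred` and `target` are stand-ins for whatever the function actually computes:

    # Inside loss_fn, after the model's prediction and the target are formed;
    # both have shape (batch, ...).
    per_example = (model_pred.astype(jnp.float32) - target.astype(jnp.float32)) ** 2
    # Reduce over all non-batch axes so the result has shape (batch,).
    return jnp.mean(per_example, axis=tuple(range(1, per_example.ndim)))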
@@ -460,10 +463,22 @@ def loss_fn(params):
   # --- Key Difference from train_step ---
   # Directly compute the loss without calculating gradients.
   # The model's state.params are used but not updated.
-  loss = loss_fn(state.params)
+  bs = len(data["latents"])
+  single_batch_size = min(8, config.global_batch_size_to_train_on)
+  losses = jnp.zeros(bs)
+  for i in range(0, bs, single_batch_size):
+    start = i
+    end = min(i + single_batch_size, bs)
+    jax.debug.print("Eval step processing samples {start} to {end}", start=start, end=end)
+    latents = data["latents"][start:end, :].astype(config.weights_dtype)
+    encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype)
+    timesteps = data["timesteps"][start:end].astype("int64")
+    loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps)
+    losses = losses.at[start:end].set(loss)

   # Structure the metrics for logging and aggregation
-  metrics = {"scalar": {"learning/eval_loss": loss}}
+  metrics = {"scalar": {"learning/eval_loss": losses}}
+  jax.debug.print("Eval step losses: {losses}", losses=losses)

   # Return the computed metrics and the new RNG key for the next eval step
   return metrics, new_rng
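If `eval_step` is jitted, the Python `for` loop above is unrolled at trace time, tracing `loss_fn` once per micro-batch; that is fine while `bs` and `single_batch_size` are static, but `jax.lax.map` would trace the body only once. A sketch under the assumptions that `bs` divides evenly by the micro-batch size and that `data` holds plain arrays; the `eval_in_microbatches` name is hypothetical, and `config` and `loss_fn` refer to the definitions in the surrounding module:

import jax
import jax.numpy as jnp

def eval_in_microbatches(params, data, micro_bs):
  def one_batch(batch):
    latents = batch["latents"].astype(config.weights_dtype)
    encoder_hidden_states = batch["encoder_hidden_states"].astype(config.weights_dtype)
    timesteps = batch["timesteps"].astype("int64")
    return loss_fn(params, latents, encoder_hidden_states, timesteps)

  # Reshape (bs, ...) -> (n_micro, micro_bs, ...) and map over the leading axis;
  # lax.map traces one_batch a single time instead of once per micro-batch.
  reshaped = jax.tree_util.tree_map(
      lambda x: x.reshape(-1, micro_bs, *x.shape[1:]), data)
  return jax.lax.map(one_batch, reshaped).reshape(-1)  # per-sample losses, shape (bs,)

One design point worth noting: every micro-batch in the loop reuses the same `new_rng` when `loss_fn` samples noise, so all micro-batches see identically seeded noise; splitting the key per micro-batch would decorrelate them.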