@@ -158,7 +158,8 @@ def get_data_shardings(self, mesh):
 
   def get_eval_data_shardings(self, mesh):
     data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
-    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": None}
+    timesteps_sharding = jax.sharding.NamedSharding(mesh, P('data'))
+    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": timesteps_sharding}
     return data_sharding
 
   def load_dataset(self, mesh, is_training=True):
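
For context on the sharding change above: a NamedSharding built from P('data') splits the leading (batch) dimension of the 1-D timesteps array across the 'data' mesh axis, instead of leaving it unplaced as None did. A minimal sketch, assuming a one-axis mesh named 'data'; the mesh and array below are illustrative and not taken from this change:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Illustrative one-axis mesh over all local devices, named 'data'.
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

    # A 1-D batch of timesteps; dimension 0 is split across the 'data' axis.
    timesteps = np.arange(2 * len(jax.devices()), dtype=np.int32)
    timesteps_sharding = NamedSharding(mesh, P("data"))
    sharded_timesteps = jax.device_put(timesteps, timesteps_sharding)
    print(sharded_timesteps.sharding)  # NamedSharding(..., spec=PartitionSpec('data',))
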
@@ -196,7 +197,7 @@ def prepare_sample_eval(features):
       latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
       encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
       timesteps = features["timesteps"]
-      print(f"timesteps in prepare_sample_eval: {timesteps}")
+      tf.print("timesteps in prepare_sample_eval:", timesteps)
       return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps}
 
     data_iterator = make_data_iterator(
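
The switch from print to tf.print above matters because prepare_sample_eval runs inside a tf.data map: the Python print only fires while the function is being traced, whereas tf.print becomes a graph op that executes for every element. A small standalone sketch; the dataset and function name here are made up for illustration:

    import tensorflow as tf

    def show(x):
      print("traced once")      # Python side effect: runs only during tracing
      tf.print("x in map:", x)  # graph op: runs for every element
      return x

    ds = tf.data.Dataset.range(3).map(show)
    list(ds.as_numpy_iterator())  # prints "x in map: 0/1/2"; "traced once" appears once
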
@@ -332,9 +333,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
         try:
           with mesh:
             eval_batch = load_next_batch(eval_data_iterator, None, self.config)
+            eval_batch["timesteps"] = jax.device_put(
+                eval_batch["timesteps"], eval_data_shardings["timesteps"]
+            )
             metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
             loss = metrics["scalar"]["learning/eval_loss"]
             timestep = int(eval_batch["timesteps"][0])
+            jax.debug.print("timesteps in eval_step: {x}", x=timestep)
             if timestep not in eval_losses_by_timestep:
               eval_losses_by_timestep[timestep] = []
             eval_losses_by_timestep[timestep].append(loss)
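
The jax.device_put call added above re-places only the "timesteps" entry of the host batch with its NamedSharding before the jitted eval step. As a side note, device_put also accepts matching pytrees, so a whole batch can be placed in one call; a self-contained sketch, with illustrative names (batch, shardings) that do not come from this change:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
    batch_sharding = NamedSharding(mesh, P("data"))

    # Toy batch and a matching dict of shardings; device_put maps over the
    # pytree, so every leaf lands on the mesh in a single call.
    n = 2 * len(jax.devices())
    batch = {"latents": np.zeros((n, 4), np.float32), "timesteps": np.arange(n, dtype=np.int32)}
    shardings = {"latents": batch_sharding, "timesteps": batch_sharding}
    batch = jax.device_put(batch, shardings)
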
@@ -349,7 +354,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
           losses = jnp.array(losses)
           losses = losses[: min(60, len(losses))]
           mean_loss = jnp.mean(losses)
-          max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
+          max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}")
           mean_per_timestep.append(mean_loss)
         final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
         max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
@@ -430,11 +435,13 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
 
   # This ensures the batch size is consistent, though it might be redundant
   # if the evaluation dataloader is already configured correctly.
+  jax.debug.print("timesteps before clip: {x}", x=data["timesteps"])
   for k, v in data.items():
     if k != "timesteps":
       data[k] = v[: config.global_batch_size_to_train_on, :]
     else:
       data[k] = v[: config.global_batch_size_to_train_on]
+  jax.debug.print("timesteps after clip: {x}", x=data["timesteps"])
 
   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
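
The two jax.debug.print calls added above work inside the jitted eval_step, where an ordinary Python print would only show a tracer once at trace time. A minimal sketch of that behaviour; the function and values are illustrative, not part of this change:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def clip_batch(t):
      jax.debug.print("timesteps before clip: {x}", x=t)  # runs on every call
      t = t[:2]
      jax.debug.print("timesteps after clip: {x}", x=t)
      return t

    clip_batch(jnp.arange(4))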