@@ -158,8 +158,7 @@ def get_data_shardings(self, mesh):
 
   def get_eval_data_shardings(self, mesh):
     data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
-    timesteps_sharding = jax.sharding.NamedSharding(mesh, P('data'))
-    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": timesteps_sharding}
+    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": data_sharding}
     return data_sharding
 
   def load_dataset(self, mesh, is_training=True):
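
Illustrative sketch, not part of this change: with every field now sharded by the same P(*config.data_sharding) spec, the whole eval batch can be placed with a single jax.device_put over the dict. The mesh axis name, batch fields, and shapes below are assumptions.

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    devices = np.array(jax.devices())                  # 1-D device grid
    mesh = Mesh(devices, axis_names=("data",))         # single 'data' axis (assumed)
    sharding = NamedSharding(mesh, P("data"))          # shard the leading batch dimension

    bsz = 2 * jax.device_count()                       # keep the batch divisible across devices
    batch = {
        "latents": np.zeros((bsz, 4, 64, 64), np.float32),
        "encoder_hidden_states": np.zeros((bsz, 77, 768), np.float32),
        "timesteps": np.zeros((bsz,), np.int32),
    }
    shardings = {key: sharding for key in batch}       # identical spec for every field
    sharded_batch = jax.device_put(batch, shardings)   # one pytree-structured placement
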
@@ -194,7 +193,6 @@ def prepare_sample_eval(features):
       latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
       encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
       timesteps = features["timesteps"]
-      tf.print("timesteps in prepare_sample_eval:", timesteps)
       return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps}
 
     data_iterator = make_data_iterator(
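
Illustrative sketch, not from this repo: a feature spec and tf.data pipeline of the kind that would feed prepare_sample_eval, assuming latents and encoder_hidden_states are stored as serialized tensors and timesteps as a scalar; the file name and dtypes are placeholders.

    import tensorflow as tf

    feature_spec = {
        "latents": tf.io.FixedLenFeature([], tf.string),                # serialized float tensor
        "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),  # serialized float tensor
        "timesteps": tf.io.FixedLenFeature([], tf.int64),               # scalar timestep
    }

    def parse_record(serialized):
      features = tf.io.parse_single_example(serialized, feature_spec)
      return {
          "latents": tf.io.parse_tensor(features["latents"], out_type=tf.float32),
          "encoder_hidden_states": tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32),
          "timesteps": features["timesteps"],
      }

    dataset = tf.data.TFRecordDataset(["eval.tfrecord"]).map(parse_record)
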
@@ -330,9 +328,6 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       try:
         with mesh:
           eval_batch = load_next_batch(eval_data_iterator, None, self.config)
-          eval_batch["timesteps"] = jax.device_put(
-              eval_batch["timesteps"], eval_data_shardings["timesteps"]
-          )
           metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
           losses = metrics["scalar"]["learning/eval_loss"]
           timesteps = eval_batch["timesteps"]
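
Aside, an assumption rather than something this diff states: the per-key device_put becomes unnecessary once the whole eval batch already carries the shardings from get_eval_data_shardings, whether placed by the data iterator or constrained via in_shardings on the jitted eval step. A sketch of the latter pattern with placeholder names (state_shardings is hypothetical):

    import functools
    import jax

    def make_p_eval_step(eval_step, scheduler, config, state_shardings, eval_data_shardings):
      # Bind the static arguments, then constrain how jit places (state, data, rng, scheduler_state).
      fn = functools.partial(eval_step, scheduler=scheduler, config=config)
      return jax.jit(
          fn,
          in_shardings=(state_shardings, eval_data_shardings, None, None),
      )
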
@@ -432,9 +427,6 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
   """
   _, new_rng = jax.random.split(rng, num=2)
 
-  # This ensures the batch size is consistent, though it might be redundant
-  # if the evaluation dataloader is already configured correctly.
-  jax.debug.print("timesteps before clip: {x}", x=data["timesteps"])
   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
   def loss_fn(params):
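
Aside: the comment above appeals to the shared denoising objective; for orientation, a generic epsilon-prediction loss sketch (the model apply function, scheduler interface, and batch fields are placeholders, not this repo's actual loss_fn):

    import jax
    import jax.numpy as jnp

    def denoising_loss(params, data, noise_rng, unet_apply, scheduler, scheduler_state):
      latents = data["latents"]
      timesteps = data["timesteps"]
      noise = jax.random.normal(noise_rng, latents.shape)
      # Forward-diffuse the clean latents to the sampled timesteps.
      noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
      # Predict the added noise conditioned on the text embeddings and score it with MSE.
      pred = unet_apply({"params": params}, noisy_latents, timesteps, data["encoder_hidden_states"])
      return jnp.mean((pred - noise) ** 2)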