@@ -156,6 +156,11 @@ def get_data_shardings(self, mesh):
     data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
     return data_sharding
 
+  def get_eval_data_shardings(self, mesh):
+    data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
+    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": None}
+    return data_sharding
+
   def load_dataset(self, mesh, is_training=True):
     # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
     # Image pre-training - txt2img 256px
@@ -170,25 +175,38 @@ def load_dataset(self, mesh, is_training=True):
       raise ValueError(
           "Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True"
       )
-
-    feature_description = {
+
+    feature_description_train = {
         "latents": tf.io.FixedLenFeature([], tf.string),
         "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
     }
 
-    def prepare_sample(features):
+    def prepare_sample_train(features):
       latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
       encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
       return {"latents": latents, "encoder_hidden_states": encoder_hidden_states}
+
+    feature_description_eval = {
+        "latents": tf.io.FixedLenFeature([], tf.string),
+        "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
+        "timesteps": tf.io.FixedLenFeature([], tf.int64),
+    }
+
+    def prepare_sample_eval(features):
+      latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
+      encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
+      timesteps = features["timesteps"]
+      print(f"timesteps in prepare_sample_eval: {timesteps}")
+      return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps}
 
     data_iterator = make_data_iterator(
         config,
         jax.process_index(),
         jax.process_count(),
         mesh,
         config.global_batch_size_to_load,
-        feature_description=feature_description,
-        prepare_sample_fn=prepare_sample,
+        feature_description=feature_description_train if is_training else feature_description_eval,
+        prepare_sample_fn=prepare_sample_train if is_training else prepare_sample_eval,
         is_training=is_training,
     )
     return data_iterator
@@ -197,7 +215,7 @@ def start_training(self):
 
     pipeline = self.load_checkpoint()
     # Generate a sample before training to compare against generated sample after training.
-    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+    # pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
 
     if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
       # save some memory.
@@ -215,8 +233,8 @@ def start_training(self):
     # Returns pipeline with trained transformer state
     pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
 
-    posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
-    print_ssim(pretrained_video_path, posttrained_video_path)
+    # posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
+    # print_ssim(pretrained_video_path, posttrained_video_path)
 
   def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
     mesh = pipeline.mesh
@@ -231,6 +249,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     state = jax.lax.with_sharding_constraint(state, state_spec)
     state_shardings = nnx.get_named_sharding(state, mesh)
     data_shardings = self.get_data_shardings(mesh)
+    eval_data_shardings = self.get_eval_data_shardings(mesh)
 
     writer = max_utils.initialize_summary_writer(self.config)
     writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
@@ -255,11 +274,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     )
     p_eval_step = jax.jit(
         functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
-        in_shardings=(state_shardings, data_shardings, None, None),
+        in_shardings=(state_shardings, eval_data_shardings, None, None),
         out_shardings=(None, None),
     )
 
     rng = jax.random.key(self.config.seed)
+    rng, eval_rng_key = jax.random.split(rng)
     start_step = 0
     last_step_completion = datetime.datetime.now()
     local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
@@ -305,24 +325,36 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
         # Re-create the iterator each time you start evaluation to reset it
         # This assumes your data loading logic can be called to get a fresh iterator.
         eval_data_iterator = self.load_dataset(mesh, is_training=False)
-        eval_rng = jax.random.key(self.config.seed + step)
-        eval_metrics = []
+        eval_rng = eval_rng_key
+        eval_losses_by_timestep = {}
         # Loop indefinitely until the iterator is exhausted
         while True:
           try:
             with mesh:
               eval_batch = load_next_batch(eval_data_iterator, None, self.config)
               metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
-              eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
+              loss = metrics["scalar"]["learning/eval_loss"]
+              timestep = int(eval_batch["timesteps"][0])
+              if timestep not in eval_losses_by_timestep:
+                eval_losses_by_timestep[timestep] = []
+              eval_losses_by_timestep[timestep].append(loss)
           except StopIteration:
             # This block is executed when the iterator has no more data
             break
         # Check if any evaluation was actually performed
-        if eval_metrics:
-          eval_loss = jnp.mean(jnp.array(eval_metrics))
-          max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
+        if eval_losses_by_timestep:
+          mean_per_timestep = []
+          max_logging.log(f"Step {step}, calculating mean loss per timestep...")
+          for timestep, losses in sorted(eval_losses_by_timestep.items()):
+            losses = jnp.array(losses)
+            losses = losses[: min(60, len(losses))]
+            mean_loss = jnp.mean(losses)
+            max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
+            mean_per_timestep.append(mean_loss)
+          final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
+          max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
           if writer:
-            writer.add_scalar("learning/eval_loss", eval_loss, step)
+            writer.add_scalar("learning/eval_loss", final_eval_loss, step)
         else:
           max_logging.log(f"Step {step}, evaluation dataset was empty.")
       example_batch = next_batch_future.result()
@@ -394,12 +426,15 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
   """
   Computes the evaluation loss for a single batch without updating model weights.
   """
-  _, new_rng, timestep_rng = jax.random.split(rng, num=3)
+  _, new_rng = jax.random.split(rng, num=2)
 
   # This ensures the batch size is consistent, though it might be redundant
   # if the evaluation dataloader is already configured correctly.
   for k, v in data.items():
-    data[k] = v[: config.global_batch_size_to_train_on, :]
+    if k != "timesteps":
+      data[k] = v[: config.global_batch_size_to_train_on, :]
+    else:
+      data[k] = v[: config.global_batch_size_to_train_on]
 
   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
@@ -410,15 +445,8 @@ def loss_fn(params):
     # Prepare inputs
     latents = data["latents"].astype(config.weights_dtype)
     encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
-    bsz = latents.shape[0]
+    timesteps = data["timesteps"].astype("int64")
 
-    # Sample random timesteps and noise, just as in a training step
-    timesteps = jax.random.randint(
-        timestep_rng,
-        (bsz,),
-        0,
-        scheduler.config.num_train_timesteps,
-    )
     noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
     noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
 
@@ -427,6 +455,7 @@ def loss_fn(params):
         hidden_states=noisy_latents,
         timestep=timesteps,
         encoder_hidden_states=encoder_hidden_states,
+        deterministic=True,
     )
 
     # Calculate the loss against the target
@@ -447,4 +476,4 @@ def loss_fn(params):
   metrics = {"scalar": {"learning/eval_loss": loss}}
 
   # Return the computed metrics and the new RNG key for the next eval step
-  return metrics, new_rng
+  return metrics, new_rng,
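
With `is_training=False`, `load_dataset` now expects eval TFRecords to carry a precomputed `timesteps` feature alongside the cached latents and text-encoder outputs. The stand-alone sketch below illustrates the record layout that `feature_description_eval` and `prepare_sample_eval` assume; the tensor shapes and the timestep value are invented placeholders, not values taken from this change.

```python
import tensorflow as tf

# Illustrative shapes only; real records hold cached VAE latents and text-encoder outputs.
latents = tf.random.normal([16, 4, 32, 32])
encoder_hidden_states = tf.random.normal([512, 4096])

example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "latents": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(latents).numpy()])
            ),
            "encoder_hidden_states": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(encoder_hidden_states).numpy()])
            ),
            # Fixed timestep per eval sample, stored as a scalar int64.
            "timesteps": tf.train.Feature(int64_list=tf.train.Int64List(value=[500])),
        }
    )
)

feature_description_eval = {
    "latents": tf.io.FixedLenFeature([], tf.string),
    "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
    "timesteps": tf.io.FixedLenFeature([], tf.int64),
}

parsed = tf.io.parse_single_example(example.SerializeToString(), feature_description_eval)
print(tf.io.parse_tensor(parsed["latents"], out_type=tf.float32).shape)  # (16, 4, 32, 32)
print(int(parsed["timesteps"]))  # 500
```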
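The eval-loss bookkeeping added to `training_loop` reduces to the small aggregation below: group per-batch losses by the batch's fixed timestep, cap each group at 60 samples, average within each timestep, then average the per-timestep means, so every timestep contributes equally to the reported `learning/eval_loss`. This is a minimal sketch; the loss values are made up.

```python
import jax.numpy as jnp

# Made-up per-batch eval losses keyed by timestep (in the trainer these come
# from p_eval_step while iterating over the eval dataset).
eval_losses_by_timestep = {
    250: [0.41, 0.39, 0.40],
    500: [0.55, 0.57],
    750: [0.62],
}

mean_per_timestep = []
for timestep, losses in sorted(eval_losses_by_timestep.items()):
    losses = jnp.array(losses)[:60]  # cap each timestep at 60 samples
    mean_loss = jnp.mean(losses)
    print(f" Mean eval loss for timestep {timestep}: {float(mean_loss):.4f}")
    mean_per_timestep.append(mean_loss)

# Unweighted mean over timesteps: a timestep with 60 samples counts the same
# as one with a single sample.
final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
print(f"Final Average Eval loss: {float(final_eval_loss):.4f}")
```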