@@ -108,7 +108,7 @@ def get_data_shardings(self, mesh):
     data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
     return data_sharding

-  def load_dataset(self, mesh):
+  def load_dataset(self, mesh, is_training=True):
     # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
     # Image pre-training - txt2img 256px
     # Image-video joint training - stage 1. 256 px images and 192px 5 sec videos at fps=16
@@ -141,6 +141,7 @@ def prepare_sample(features):
         config.global_batch_size_to_load,
         feature_description=feature_description,
         prepare_sample_fn=prepare_sample,
+        is_training=is_training,
     )
     return data_iterator

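The new is_training flag is simply threaded through to the underlying TFRecord iterator. A minimal sketch of what such a flag usually controls in a tf.data-style pipeline (build_iterator, the shuffle buffer size, and the exact dataset calls are assumptions for illustration, not the real maxdiffusion loader): training shuffles and repeats forever, while evaluation makes a single deterministic pass so the iterator eventually raises StopIteration, which the eval loop further down relies on.

import tensorflow as tf

def build_iterator(file_pattern, parse_fn, batch_size, is_training=True):
  # Hypothetical sketch of an is_training-aware TFRecord pipeline.
  ds = tf.data.TFRecordDataset(tf.io.gfile.glob(file_pattern))
  ds = ds.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
  if is_training:
    # Training: shuffle and repeat forever so the prefetching loop never runs dry.
    ds = ds.shuffle(1024).repeat()
  # Evaluation (is_training=False): one deterministic pass; the iterator raises
  # StopIteration when the data is exhausted.
  ds = ds.batch(batch_size, drop_remainder=is_training)
  return iter(ds.prefetch(tf.data.AUTOTUNE))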
@@ -155,20 +156,20 @@ def start_training(self):
     del pipeline.vae_cache

     mesh = pipeline.mesh
-    data_iterator = self.load_dataset(mesh)
+    train_data_iterator = self.load_dataset(mesh, is_training=True)

     # Load FlowMatch scheduler
     scheduler, scheduler_state = self.create_scheduler()
     pipeline.scheduler = scheduler
     pipeline.scheduler_state = scheduler_state
     optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5)
     # Returns pipeline with trained transformer state
-    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, data_iterator)
+    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)

     posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
     print_ssim(pretrained_video_path, posttrained_video_path)

-  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_iterator):
+  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
     mesh = pipeline.mesh
     graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)

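nnx.split and nnx.merge are what let the jitted train/eval steps treat the transformer as plain pytrees: the module is split into a static graph definition, the trainable parameters, and the remaining state, and eval_step later rebuilds a callable module with nnx.merge(state.graphdef, params, state.rest_of_state). A small round-trip on a toy module, assuming the current flax.nnx API (ToyModel itself is illustrative):

import jax.numpy as jnp
from flax import nnx

class ToyModel(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.proj = nnx.Linear(4, 2, rngs=rngs)

  def __call__(self, x):
    return self.proj(x)

model = ToyModel(rngs=nnx.Rngs(0))
# Split into a static graph definition, trainable params, and everything else.
graphdef, params, rest_of_state = nnx.split(model, nnx.Param, ...)
# Inside a jitted step, the module is rebuilt from those pieces and called as usual.
rebuilt = nnx.merge(graphdef, params, rest_of_state)
out = rebuilt(jnp.ones((1, 4)))  # shape (1, 2)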
@@ -203,6 +204,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
         out_shardings=(state_shardings, None, None, None),
         donate_argnums=(0,),
     )
+    p_eval_step = jax.jit(
+        functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
+        in_shardings=(state_shardings, data_shardings, None, None),
+        out_shardings=(None, None),
+    )
+
     rng = jax.random.key(self.config.seed)
     start_step = 0
     last_step_completion = datetime.datetime.now()
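p_eval_step mirrors the jit setup used for p_train_step: scheduler and config are closed over with functools.partial (so they behave as static arguments), the model state and batch keep their mesh shardings, and the RNG and scheduler state are left unconstrained. The same pattern on a stripped-down step function, where the single "data" mesh axis and toy_eval_step are assumptions for illustration:

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def toy_eval_step(params, batch, rng, scale):
  # Loss-only step: no gradients, no state update.
  loss = jnp.mean((batch * params - scale) ** 2)
  return {"scalar": {"learning/eval_loss": loss}}, jax.random.split(rng)[0]

mesh = Mesh(np.array(jax.devices()), ("data",))
replicated = NamedSharding(mesh, P())           # params live on every device
batch_sharded = NamedSharding(mesh, P("data"))  # batch is split along the data axis

p_toy_eval_step = jax.jit(
    functools.partial(toy_eval_step, scale=2.0),     # bake in non-array args, like scheduler/config
    in_shardings=(replicated, batch_sharded, None),  # None: keep whatever sharding the arg arrives with
    out_shardings=(None, None),                      # let the compiler choose output shardings
)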
@@ -219,13 +226,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
     per_device_tflops = self.calculate_tflops(pipeline)

     scheduler_state = pipeline.scheduler_state
-    example_batch = load_next_batch(data_iterator, None, self.config)
+    example_batch = load_next_batch(train_data_iterator, None, self.config)
     with ThreadPoolExecutor(max_workers=1) as executor:
       for step in np.arange(start_step, self.config.max_train_steps):
         if self.config.enable_profiler and step == first_profiling_step:
           max_utils.activate_profiler(self.config)
         start_step_time = datetime.datetime.now()
-        next_batch_future = executor.submit(load_next_batch, data_iterator, example_batch, self.config)
+        next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
         with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(
             self.config.logical_axis_rules
         ):
@@ -241,6 +248,31 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
         )
         if self.config.write_metrics:
           train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
+
+        if self.config.eval_every > 0 and (step + 1) % self.config.eval_every == 0:
+          # Re-create the eval iterator at the start of every evaluation so it begins from the first batch.
+          # This assumes load_dataset can be called again to produce a fresh, single-pass iterator.
+          eval_data_iterator = self.load_dataset(mesh, is_training=False)
+          eval_rng = jax.random.key(self.config.seed + step)
+          eval_metrics = []
+          # Loop until the eval iterator is exhausted.
+          while True:
+            try:
+              with mesh:
+                eval_batch = load_next_batch(eval_data_iterator, None, self.config)
+                metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
+                eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
+            except StopIteration:
+              # The iterator has no more data.
+              break
+          # Only report a result if at least one eval batch was processed.
+          if eval_metrics:
+            eval_loss = jnp.mean(jnp.array(eval_metrics))
+            max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
+            if writer:
+              writer.add_scalar("learning/eval_loss", eval_loss, step)
+          else:
+            max_logging.log(f"Step {step}, evaluation dataset was empty.")
         example_batch = next_batch_future.result()

     _metrics_queue.put(None)
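The evaluation block follows a standard pattern: build a fresh single-pass iterator, run the jitted eval step on every batch until StopIteration, then average the per-batch losses once at the end. The same drain-and-average logic in isolation (eval_step_fn and the metric key layout follow the code above; the helper itself is just a sketch):

import jax.numpy as jnp

def run_evaluation(eval_iterator, eval_step_fn, state, rng, scheduler_state):
  # Drain a single-pass iterator and average per-batch eval losses.
  losses = []
  while True:
    try:
      batch = next(eval_iterator)
    except StopIteration:
      break  # single-pass iterator is exhausted
    metrics, rng = eval_step_fn(state, batch, rng, scheduler_state)
    losses.append(metrics["scalar"]["learning/eval_loss"])
  # Returns None when the eval dataset was empty, so the caller can skip logging.
  return (jnp.mean(jnp.array(losses)) if losses else None), rng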
@@ -296,3 +328,62 @@ def loss_fn(params):
   new_state = state.apply_gradients(grads=grads)
   metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
   return new_state, scheduler_state, metrics, new_rng
+
+def eval_step(state, data, rng, scheduler_state, scheduler, config):
+  """
+  Computes the evaluation loss for a single batch without updating model weights.
+  """
+  _, new_rng, timestep_rng = jax.random.split(rng, num=3)
+
+  # Keep the batch size consistent with training, though this may be redundant
+  # if the evaluation dataloader is already configured correctly.
+  for k, v in data.items():
+    data[k] = v[:config.global_batch_size_to_train_on, :]
+
+  # The loss function logic is identical to training: we evaluate the model's
+  # ability to perform its core training objective (denoising).
+  def loss_fn(params):
+    # Reconstruct the model from its graph definition and parameters
+    model = nnx.merge(state.graphdef, params, state.rest_of_state)
+
+    # Prepare inputs
+    latents = data["latents"].astype(config.weights_dtype)
+    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
+    bsz = latents.shape[0]
+
+    # Sample random timesteps and noise, just as in a training step
+    timesteps = jax.random.randint(
+        timestep_rng,
+        (bsz,),
+        0,
+        scheduler.config.num_train_timesteps,
+    )
+    noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
+    noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
+
+    # Get the model's prediction
+    model_pred = model(
+        hidden_states=noisy_latents,
+        timestep=timesteps,
+        encoder_hidden_states=encoder_hidden_states,
+    )
+
+    # Compute the per-timestep weighted loss against the scheduler's training target
+    training_target = scheduler.training_target(latents, noise, timesteps)
+    training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
+    loss = (training_target - model_pred) ** 2
+    loss = loss * training_weight
+    loss = jnp.mean(loss)
+
+    return loss
+
+  # --- Key difference from train_step ---
+  # Compute the loss directly, without calculating gradients.
+  # The model's state.params are read but never updated.
+  loss = loss_fn(state.params)
+
+  # Structure the metrics for logging and aggregation
+  metrics = {"scalar": {"learning/eval_loss": loss}}
+
+  # Return the computed metrics and the new RNG key for the next eval step
+  return metrics, new_rng
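The eval loss is exactly the training objective computed without a gradient step: a per-timestep weighted MSE between the scheduler's training target and the model prediction. The weighting logic, pulled out into a standalone helper for clarity (the reshape-based broadcast is an assumption that generalizes the axis=(1, 2, 3, 4) expand_dims above to any latent rank):

import jax.numpy as jnp

def weighted_denoising_loss(model_pred, training_target, timestep_weight):
  # model_pred, training_target: (batch, ...) arrays with identical shapes.
  # timestep_weight: (batch,) per-sample weights from the scheduler.
  # Broadcast the per-sample weight across every non-batch dimension.
  weight = timestep_weight.reshape((-1,) + (1,) * (model_pred.ndim - 1))
  return jnp.mean(weight * (training_target - model_pred) ** 2)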