@@ -101,7 +101,7 @@ def get_data_shardings(self, mesh):
101101 data_sharding = {"latents" : data_sharding , "encoder_hidden_states" : data_sharding }
102102 return data_sharding
103103
104- def load_dataset (self , mesh ):
104+ def load_dataset (self , mesh , is_training = True ):
105105 # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
106106 # Image pre-training - txt2img 256px
107107 # Image-video joint training - stage 1. 256 px images and 192px 5 sec videos at fps=16
@@ -134,6 +134,7 @@ def prepare_sample(features):
134134 config .global_batch_size_to_load ,
135135 feature_description = feature_description ,
136136 prepare_sample_fn = prepare_sample ,
137+ is_training = is_training ,
137138 )
138139 return data_iterator
139140
@@ -145,20 +146,20 @@ def start_training(self):
145146 # Generate a sample before training to compare against generated sample after training.
146147 pretrained_video_path = generate_sample (self .config , pipeline , filename_prefix = "pre-training-" )
147148 mesh = pipeline .mesh
148- data_iterator = self .load_dataset (mesh )
149+ train_data_iterator = self .load_dataset (mesh , is_training = True )
149150
150151 # Load FlowMatch scheduler
151152 scheduler , scheduler_state = self .create_scheduler ()
152153 pipeline .scheduler = scheduler
153154 pipeline .scheduler_state = scheduler_state
154155 optimizer , learning_rate_scheduler = self ._create_optimizer (pipeline .transformer , self .config , 1e-5 )
155156 # Returns pipeline with trained transformer state
156- pipeline = self .training_loop (pipeline , optimizer , learning_rate_scheduler , data_iterator )
157+ pipeline = self .training_loop (pipeline , optimizer , learning_rate_scheduler , train_data_iterator )
157158
158159 posttrained_video_path = generate_sample (self .config , pipeline , filename_prefix = "post-training-" )
159160 print_ssim (pretrained_video_path , posttrained_video_path )
160161
161- def training_loop (self , pipeline , optimizer , learning_rate_scheduler , data_iterator ):
162+ def training_loop (self , pipeline , optimizer , learning_rate_scheduler , train_data_iterator ):
162163 mesh = pipeline .mesh
163164 graphdef , params , rest_of_state = nnx .split (pipeline .transformer , nnx .Param , ...)
164165
@@ -193,6 +194,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
193194 out_shardings = (state_shardings , None , None , None ),
194195 donate_argnums = (0 ,),
195196 )
197+ p_eval_step = jax .jit (
198+ functools .partial (eval_step , scheduler = pipeline .scheduler , config = self .config ),
199+ in_shardings = (state_shardings , data_shardings , None , None ),
200+ out_shardings = (None , None ),
201+ )
202+
196203 rng = jax .random .key (self .config .seed )
197204 start_step = 0
198205 last_step_completion = datetime .datetime .now ()
@@ -209,13 +216,13 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
209216 per_device_tflops = self .calculate_tflops (pipeline )
210217
211218 scheduler_state = pipeline .scheduler_state
212- example_batch = load_next_batch (data_iterator , None , self .config )
219+ example_batch = load_next_batch (train_data_iterator , None , self .config )
213220 with ThreadPoolExecutor (max_workers = 1 ) as executor :
214221 for step in np .arange (start_step , self .config .max_train_steps ):
215222 if self .config .enable_profiler and step == first_profiling_step :
216223 max_utils .activate_profiler (self .config )
217224 start_step_time = datetime .datetime .now ()
218- next_batch_future = executor .submit (load_next_batch , data_iterator , example_batch , self .config )
225+ next_batch_future = executor .submit (load_next_batch , train_data_iterator , example_batch , self .config )
219226 with jax .profiler .StepTraceAnnotation ("train" , step_num = step ), pipeline .mesh , nn_partitioning .axis_rules (
220227 self .config .logical_axis_rules
221228 ):
@@ -231,6 +238,31 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
231238 )
232239 if self .config .write_metrics :
233240 train_utils .write_metrics (writer , local_metrics_file , running_gcs_metrics , train_metric , step , self .config )
241+
242+ if self .config .eval_every > 0 and (step + 1 ) % self .config .eval_every == 0 :
243+ # Re-create the iterator each time you start evaluation to reset it
244+ # This assumes your data loading logic can be called to get a fresh iterator.
245+ eval_data_iterator = self .load_dataset (mesh , is_training = False )
246+ eval_rng = jax .random .key (self .config .seed + step )
247+ eval_metrics = []
248+ # Loop indefinitely until the iterator is exhausted
249+ while True :
250+ try :
251+ with mesh :
252+ eval_batch = load_next_batch (eval_data_iterator , None , self .config )
253+ metrics , eval_rng = p_eval_step (state , eval_batch , eval_rng , scheduler_state )
254+ eval_metrics .append (metrics ["scalar" ]["learning/eval_loss" ])
255+ except StopIteration :
256+ # This block is executed when the iterator has no more data
257+ break
258+ # Check if any evaluation was actually performed
259+ if eval_metrics :
260+ eval_loss = jnp .mean (jnp .array (eval_metrics ))
261+ max_logging .log (f"Step { step } , Eval loss: { eval_loss :.4f} " )
262+ if writer :
263+ writer .add_scalar ("learning/eval_loss" , eval_loss , step )
264+ else :
265+ max_logging .log (f"Step { step } , evaluation dataset was empty." )
234266 example_batch = next_batch_future .result ()
235267
236268 _metrics_queue .put (None )
@@ -286,3 +318,62 @@ def loss_fn(params):
286318 new_state = state .apply_gradients (grads = grads )
287319 metrics = {"scalar" : {"learning/loss" : loss }, "scalars" : {}}
288320 return new_state , scheduler_state , metrics , new_rng
321+
def eval_step(state, data, rng, scheduler_state, scheduler, config):
  """Computes the evaluation loss for a single batch without updating model weights.

  Mirrors the training objective (denoising loss) so eval loss is directly
  comparable to train loss, but performs no gradient computation or update.

  Args:
    state: Train state carrying ``graphdef``, ``params`` and ``rest_of_state``
      so the transformer can be rebuilt via ``nnx.merge``.
    data: Batch dict with ``"latents"`` and ``"encoder_hidden_states"`` arrays.
      Assumes each value has at least 2 dims (sliced as ``[:bsz, :]``) —
      TODO confirm against the eval dataloader's output shapes.
    rng: PRNG key for this evaluation step.
    scheduler_state: State object for the FlowMatch scheduler.
    scheduler: Scheduler exposing ``add_noise``, ``training_target``,
      ``training_weight`` and ``config.num_train_timesteps``.
    config: Config providing ``weights_dtype`` and
      ``global_batch_size_to_train_on``.

  Returns:
    A ``(metrics, new_rng)`` tuple: ``metrics`` holds
    ``{"scalar": {"learning/eval_loss": loss}}`` and ``new_rng`` is a fresh,
    unconsumed key for the next eval step.
  """
  # Derive three independent keys: one to hand back to the caller, one for
  # noise sampling, one for timestep sampling. The previous version reused
  # the returned key for noise generation — in JAX a key must never be
  # consumed twice, so the returned key has to stay untouched here.
  new_rng, noise_rng, timestep_rng = jax.random.split(rng, num=3)

  # Trim to the training batch size without mutating the caller's dict.
  # (Likely redundant if the eval dataloader already emits this batch size.)
  data = {k: v[: config.global_batch_size_to_train_on, :] for k, v in data.items()}

  def loss_fn(params):
    """Denoising loss for one batch — identical objective to train_step."""
    # Reconstruct the model from its definition and parameters.
    model = nnx.merge(state.graphdef, params, state.rest_of_state)

    # Prepare inputs in the configured compute dtype.
    latents = data["latents"].astype(config.weights_dtype)
    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
    bsz = latents.shape[0]

    # Sample random timesteps and noise, just as in a training step.
    timesteps = jax.random.randint(
        timestep_rng,
        (bsz,),
        0,
        scheduler.config.num_train_timesteps,
    )
    noise = jax.random.normal(key=noise_rng, shape=latents.shape, dtype=latents.dtype)
    noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)

    # Model prediction for the noised latents.
    model_pred = model(
        hidden_states=noisy_latents,
        timestep=timesteps,
        encoder_hidden_states=encoder_hidden_states,
    )

    # Weighted MSE against the scheduler-defined target. The weight is
    # broadcast over the (C, T, H, W)-style trailing axes of 5-D latents.
    training_target = scheduler.training_target(latents, noise, timesteps)
    training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
    loss = (training_target - model_pred) ** 2
    loss = loss * training_weight
    loss = jnp.mean(loss)

    return loss

  # Key difference from train_step: compute the loss directly, with no
  # gradient computation — state.params are read but never updated.
  loss = loss_fn(state.params)

  # Structure the metrics for logging and aggregation.
  metrics = {"scalar": {"learning/eval_loss": loss}}

  # Return metrics plus a fresh key for the next eval step.
  return metrics, new_rng
0 commit comments