@@ -209,7 +209,11 @@ def prepare_sample_eval(features):
209209
210210 def start_training (self ):
211211
212- pipeline = self .load_checkpoint ()
212+ pipeline , opt_state , step = self .load_checkpoint ()
213+ restore_args = {}
214+ if opt_state and step :
215+ restore_args = {"opt_state" : opt_state , "step" :step }
216+ del opt_state
213217 if self .config .enable_ssim :
214218 # Generate a sample before training to compare against generated sample after training.
215219 pretrained_video_path = generate_sample (self .config , pipeline , filename_prefix = "pre-training-" )
@@ -228,7 +232,7 @@ def start_training(self):
228232 pipeline .scheduler_state = scheduler_state
229233 optimizer , learning_rate_scheduler = self ._create_optimizer (pipeline .transformer , self .config , 1e-5 )
230234 # Returns pipeline with trained transformer state
231- pipeline = self .training_loop (pipeline , optimizer , learning_rate_scheduler , train_data_iterator )
235+ pipeline = self .training_loop (pipeline , optimizer , learning_rate_scheduler , train_data_iterator , restore_args )
232236
233237 if self .config .enable_ssim :
234238 posttrained_video_path = generate_sample (self .config , pipeline , filename_prefix = "post-training-" )
@@ -280,18 +284,29 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr
280284 if writer :
281285 writer .add_scalar ("learning/eval_loss" , final_eval_loss , step )
282286
283- def training_loop (self , pipeline , optimizer , learning_rate_scheduler , train_data_iterator ):
287+ def training_loop (self , pipeline , optimizer , learning_rate_scheduler , train_data_iterator , restore_args : dict = {} ):
284288 mesh = pipeline .mesh
285289 graphdef , params , rest_of_state = nnx .split (pipeline .transformer , nnx .Param , ...)
286290
287291 with mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
288292 state = TrainState .create (
289- apply_fn = graphdef .apply , params = params , tx = optimizer , graphdef = graphdef , rest_of_state = rest_of_state
290- )
293+ apply_fn = graphdef .apply , params = params , tx = optimizer , graphdef = graphdef , rest_of_state = rest_of_state )
294+ if restore_args :
295+ step = restore_args .get ("step" , 0 )
296+ max_logging .log (f"Restoring optimizer and resuming from step { step } " )
297+ state = state .replace (opt_state = restore_args .get ("opt_state" ), step = restore_args .get ("step" , 0 ))
298+ del restore_args ["opt_state" ]
299+ del optimizer
291300 state = jax .tree .map (_to_array , state )
292301 state_spec = nnx .get_partition_spec (state )
293302 state = jax .lax .with_sharding_constraint (state , state_spec )
294303 state_shardings = nnx .get_named_sharding (state , mesh )
304+ if jax .process_index () == 0 and restore_args :
305+ max_logging .log ("--- Optimizer State Sharding Spec (opt_state) ---" )
306+ # Use pprint for a clean, readable tree structure
307+ from pprint import pprint
308+ pprint (state_spec .opt_state )
309+ max_logging .log ("------------------------------------------------" )
295310 data_shardings = self .get_data_shardings (mesh )
296311 eval_data_shardings = self .get_eval_data_shardings (mesh )
297312
@@ -334,8 +349,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
334349 last_profiling_step = np .clip (
335350 first_profiling_step + self .config .profiler_steps - 1 , first_profiling_step , self .config .max_train_steps - 1
336351 )
337- # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint.
338- start_step = 0
352+ start_step = restore_args .get ("step" , 0 )
353+ if start_step :
354+ max_logging .log (f"Resuming training from step {start_step}" )
339355 per_device_tflops , _ , _ = WanTrainer .calculate_tflops (pipeline )
340356 scheduler_state = pipeline .scheduler_state
341357 example_batch = load_next_batch (train_data_iterator , None , self .config )
@@ -373,8 +389,11 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
373389 example_batch = next_batch_future .result ()
374390 if step != 0 and self .config .checkpoint_every != - 1 and step % self .config .checkpoint_every == 0 :
375391 max_logging .log (f"Saving checkpoint for step { step } " )
376- self .save_checkpoint (step , pipeline , state .params )
377-
392+ if self .config .save_optimizer :
393+ self .save_checkpoint (step , pipeline , state )
394+ else :
395+ self .save_checkpoint (step , pipeline , state .params )
396+
378397 _metrics_queue .put (None )
379398 writer_thread .join ()
380399 if writer :
0 commit comments