1414 limitations under the License.
1515 """
1616
17+ import os
18+ import datetime
1719import functools
1820import numpy as np
1921import jax .numpy as jnp
2022import jax
2123import jax .tree_util as jtu
2224from flax import nnx
2325from ..schedulers import FlaxEulerDiscreteScheduler
24- from .. import max_utils
25- from .. import max_logging
26+ from .. import max_utils , max_logging , train_utils
2627from ..checkpointing .wan_checkpointer import (
2728 WanCheckpointer ,
2829 WAN_CHECKPOINT
@@ -35,6 +36,8 @@ def __init__(self, config):
3536 if config .train_text_encoder :
3637 raise ValueError ("this script currently doesn't support training text_encoders" )
3738
39+ self .global_batch_size = self .config .per_device_batch_size * jax .device_count ()
40+
3841 def post_training_steps (self , pipeline , params , train_states , msg = "" ):
3942 pass
4043
@@ -49,7 +52,8 @@ def create_scheduler(self, pipeline, params):
4952 return noise_scheduler , noise_scheduler_state
5053
5154 def calculate_tflops (self , pipeline ):
52- pass
55+ max_logging .log (f"WARNING : Calculting tflops is not implemented in Wan 2.1. Returning 0..." )
56+ return 0
5357
5458 def load_dataset (self , pipeline ):
5559 # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
@@ -60,10 +64,9 @@ def load_dataset(self, pipeline):
6064 # prompt embeds shape: (1, 512, 4096)
6165 # For now, we will pass the same latents over and over
6266 # TODO - create a dataset
63- global_batch_size = self .config .per_device_batch_size * jax .device_count ()
64- prompt_embeds = jax .random .normal (jax .random .key (self .config .seed ), (global_batch_size , 512 , 4096 ))
67+ prompt_embeds = jax .random .normal (jax .random .key (self .config .seed ), (self .global_batch_size , 512 , 4096 ))
6568 latents = pipeline .prepare_latents (
66- global_batch_size ,
69+ self . global_batch_size ,
6770 vae_scale_factor_temporal = pipeline .vae_scale_factor_temporal ,
6871 vae_scale_factor_spatial = pipeline .vae_scale_factor_spatial ,
6972 height = self .config .height ,
@@ -92,23 +95,61 @@ def start_training(self):
9295 #graphdef, state = nnx.plit((pipeline.transformer, optimizer))
9396 dummy_inputs = self .load_dataset (pipeline )
9497 dummy_inputs = tuple ([jtu .tree_map_with_path (functools .partial (_form_global_array , global_mesh = mesh ), input ) for input in dummy_inputs ])
95-
9698 self .training_loop (pipeline , optimizer , learning_rate_scheduler , dummy_inputs )
9799
98100 def training_loop (self , pipeline , optimizer , learning_rate_scheduler , data ):
99101
100102 graphdef , state = nnx .split ((pipeline .transformer , optimizer ))
103+ writer = max_utils .initialize_summary_writer (self .config )
104+ num_model_parameters = max_utils .calculate_num_params_from_pytree (state [0 ])
105+ max_utils .add_text_to_summary_writer ("number_model_parameters" , str (num_model_parameters ), writer )
106+ max_utils .add_text_to_summary_writer ("libtpu_init_args" , os .environ .get ("LIBTPU_INIT_ARGS" , "" ), writer )
107+ max_utils .add_config_to_summary_writer (self .config , writer )
108+
109+ if jax .process_index () == 0 :
110+ max_logging .log ("***** Running training *****" )
111+ max_logging .log (f" Instantaneous batch size per device = { self .config .per_device_batch_size } " )
112+ max_logging .log (f" Total train batch size (w. parallel & distributed) = { self .global_batch_size } " )
113+ max_logging .log (f" Total optimization steps = { self .config .max_train_steps } " )
114+
115+
101116 state = state .to_pure_dict ()
102117 p_train_step = jax .jit (
103118 train_step ,
104119 donate_argnums = (1 ,),
105120 )
106121 rng = jax .random .key (self .config .seed )
107122 start_step = 0
123+ last_step_completion = datetime .datetime .now ()
124+ local_metrics_file = open (self .config .metrics_file , "a" , encoding = "utf8" ) if self .config .metrics_file else None
125+ running_gcs_metrics = [] if self .config .gcs_metrics else None
126+ first_profiling_step = self .config .skip_first_n_steps_for_profiler
127+ if self .config .enable_profiler and first_profiling_step >= self .config .max_train_steps :
128+ raise ValueError ("Profiling requested but initial profiling step set past training final step" )
129+ last_profiling_step = np .clip (
130+ first_profiling_step + self .config .profiler_steps - 1 , first_profiling_step , self .config .max_train_steps - 1
131+ )
132+ # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint.
133+ start_step = 0
134+ per_device_tflops = self .calculate_tflops (pipeline )
135+
108136 for step in np .arange (start_step , self .config .max_train_steps ):
109- with pipeline .mesh :
110- loss , state , rng = p_train_step (graphdef , state , data , rng )
111- max_logging .log (f"loss: { loss } " )
137+ if self .config .enable_profiler and step == first_profiling_step :
138+ max_utils .activate_profiler (self .config )
139+ with jax .profiler .StepTraceAnnotation ("train" , step_num = step ), pipeline .mesh :
140+ state , train_metric , rng = p_train_step (graphdef , state , data , rng )
141+
142+ new_time = datetime .datetime .now ()
143+
144+ if self .config .enable_profiler and step == last_profiling_step :
145+ max_utils .deactivate_profiler (self .config )
146+
147+ train_utils .record_scalar_metrics (
148+ train_metric , new_time - last_step_completion , per_device_tflops , learning_rate_scheduler (step )
149+ )
150+ if self .config .write_metrics :
151+ train_utils .write_metrics (writer , local_metrics_file , running_gcs_metrics , train_metric , step , self .config )
152+ last_step_completion = new_time
112153
def train_step(graphdef, state, data, rng):
  """One optimization step; thin free-function wrapper around step_optimizer
  so it can be jax.jit-ted with donated state (see training_loop)."""
  new_state, metrics, new_rng = step_optimizer(graphdef, state, data, rng)
  return new_state, metrics, new_rng
@@ -145,4 +186,5 @@ def loss_fn(model):
145186 optimizer .update (grads )
146187 state = nnx .state ((model , optimizer ))
147188 state = state .to_pure_dict ()
148- return loss , state , new_rng
189+ metrics = {"scalar" : {"learning/loss" : loss }, "scalars" : {}}
190+ return state , metrics , new_rng
0 commit comments