Skip to content

Commit a60d235

Browse files
working training pipeline on v5p at num_frames=1
1 parent 05f0554 commit a60d235

1 file changed

Lines changed: 53 additions & 11 deletions

File tree

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@
1414
limitations under the License.
1515
"""
1616

17+
import os
18+
import datetime
1719
import functools
1820
import numpy as np
1921
import jax.numpy as jnp
2022
import jax
2123
import jax.tree_util as jtu
2224
from flax import nnx
2325
from ..schedulers import FlaxEulerDiscreteScheduler
24-
from .. import max_utils
25-
from .. import max_logging
26+
from .. import max_utils, max_logging, train_utils
2627
from ..checkpointing.wan_checkpointer import (
2728
WanCheckpointer,
2829
WAN_CHECKPOINT
@@ -35,6 +36,8 @@ def __init__(self, config):
3536
if config.train_text_encoder:
3637
raise ValueError("this script currently doesn't support training text_encoders")
3738

39+
self.global_batch_size = self.config.per_device_batch_size * jax.device_count()
40+
3841
def post_training_steps(self, pipeline, params, train_states, msg=""):
  """Hook called once training finishes; intentionally a no-op for this trainer."""
  return None
4043

@@ -49,7 +52,8 @@ def create_scheduler(self, pipeline, params):
4952
return noise_scheduler, noise_scheduler_state
5053

5154
def calculate_tflops(self, pipeline):
  """Return the per-device TFLOPs estimate for one train step.

  Not yet implemented for Wan 2.1: logs a warning and returns 0 so callers
  (e.g. metric recording in the training loop) can still proceed.
  """
  # Fixed typo ("Calculting") and dropped the needless f-prefix — the
  # message contains no interpolated placeholders.
  max_logging.log("WARNING : Calculating tflops is not implemented in Wan 2.1. Returning 0...")
  return 0
5357

5458
def load_dataset(self, pipeline):
5559
# Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
@@ -60,10 +64,9 @@ def load_dataset(self, pipeline):
6064
# prompt embeds shape: (1, 512, 4096)
6165
# For now, we will pass the same latents over and over
6266
# TODO - create a dataset
63-
global_batch_size = self.config.per_device_batch_size * jax.device_count()
64-
prompt_embeds = jax.random.normal(jax.random.key(self.config.seed), (global_batch_size, 512, 4096))
67+
prompt_embeds = jax.random.normal(jax.random.key(self.config.seed), (self.global_batch_size, 512, 4096))
6568
latents = pipeline.prepare_latents(
66-
global_batch_size,
69+
self.global_batch_size,
6770
vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
6871
vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
6972
height=self.config.height,
@@ -92,23 +95,61 @@ def start_training(self):
9295
#graphdef, state = nnx.plit((pipeline.transformer, optimizer))
9396
dummy_inputs = self.load_dataset(pipeline)
9497
dummy_inputs = tuple([jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), input) for input in dummy_inputs])
95-
9698
self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs)
9799

98100
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data):
  """Run the jitted train step for ``config.max_train_steps`` steps.

  Splits (transformer, optimizer) into a pure NNX state pytree, jits
  ``train_step`` with the state donated, and iterates, recording per-step
  scalar metrics (loss, step time, learning rate) to the summary writer and
  optional local/GCS metrics sinks. Optionally captures a profiler trace
  starting at ``skip_first_n_steps_for_profiler`` for ``profiler_steps`` steps.

  Args:
    pipeline: Wan pipeline providing ``transformer`` and the device ``mesh``.
    optimizer: NNX optimizer wrapping the transformer parameters.
    learning_rate_scheduler: callable mapping step -> learning rate (for logging).
    data: pre-sharded batch pytree fed to every step (no dataset yet — see load_dataset).
  """
  graphdef, state = nnx.split((pipeline.transformer, optimizer))
  writer = max_utils.initialize_summary_writer(self.config)
  num_model_parameters = max_utils.calculate_num_params_from_pytree(state[0])
  max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer)
  max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS", ""), writer)
  max_utils.add_config_to_summary_writer(self.config, writer)

  if jax.process_index() == 0:
    max_logging.log("***** Running training *****")
    max_logging.log(f" Instantaneous batch size per device = {self.config.per_device_batch_size}")
    max_logging.log(f" Total train batch size (w. parallel & distributed) = {self.global_batch_size}")
    max_logging.log(f" Total optimization steps = {self.config.max_train_steps}")

  state = state.to_pure_dict()
  # donate_argnums=(1,) lets XLA reuse the (large) state buffers in place.
  p_train_step = jax.jit(
      train_step,
      donate_argnums=(1,),
  )
  rng = jax.random.key(self.config.seed)
  last_step_completion = datetime.datetime.now()
  local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
  running_gcs_metrics = [] if self.config.gcs_metrics else None
  first_profiling_step = self.config.skip_first_n_steps_for_profiler
  if self.config.enable_profiler and first_profiling_step >= self.config.max_train_steps:
    raise ValueError("Profiling requested but initial profiling step set past training final step")
  last_profiling_step = np.clip(
      first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1
  )
  # TODO - 0 needs to be changed to last step if continuing from an orbax checkpoint.
  # (The original assigned start_step = 0 twice; the duplicate was removed.)
  start_step = 0
  per_device_tflops = self.calculate_tflops(pipeline)

  try:
    for step in np.arange(start_step, self.config.max_train_steps):
      if self.config.enable_profiler and step == first_profiling_step:
        max_utils.activate_profiler(self.config)
      with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh:
        state, train_metric, rng = p_train_step(graphdef, state, data, rng)

      new_time = datetime.datetime.now()

      if self.config.enable_profiler and step == last_profiling_step:
        max_utils.deactivate_profiler(self.config)

      train_utils.record_scalar_metrics(
          train_metric, new_time - last_step_completion, per_device_tflops, learning_rate_scheduler(step)
      )
      if self.config.write_metrics:
        train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config)
      last_step_completion = new_time
  finally:
    # The metrics file was opened above and previously leaked; close it even
    # if a step raises.
    if local_metrics_file:
      local_metrics_file.close()

113154
def train_step(graphdef, state, data, rng):
  """One optimization step: thin, jit-friendly wrapper around step_optimizer."""
  step_args = (graphdef, state, data, rng)
  return step_optimizer(*step_args)
@@ -145,4 +186,5 @@ def loss_fn(model):
145186
optimizer.update(grads)
146187
state = nnx.state((model, optimizer))
147188
state = state.to_pure_dict()
148-
return loss, state, new_rng
189+
metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
190+
return state, metrics, new_rng

0 commit comments

Comments
 (0)