
Commit 05f0554

initial commit for wan training
1 parent b84fc34 commit 05f0554

7 files changed

Lines changed: 276 additions & 1 deletion

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC

from flax import nnx

from maxdiffusion.checkpointing.checkpointing_utils import create_orbax_checkpoint_manager
from ..pipelines.wan.wan_pipeline import WanPipeline
from .. import max_logging, max_utils

WAN_CHECKPOINT = "WAN_CHECKPOINT"


class WanCheckpointer(ABC):

  def __init__(self, config, checkpoint_type):
    self.config = config
    self.checkpoint_type = checkpoint_type

    self.checkpoint_manager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, model, config, learning_rate):
    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return nnx.Optimizer(model, tx), learning_rate_scheduler

  def load_wan_configs_from_orbax(self, step):
    max_logging.log("Restoring WAN configs")
    if step is None:
      step = self.checkpoint_manager.latest_step()
      if step is None:
        return None

  def load_diffusers_checkpoint(self):
    pipeline = WanPipeline.from_pretrained(self.config)
    return pipeline

  def load_checkpoint(self, step=None):
    model_configs = self.load_wan_configs_from_orbax(step)

    if model_configs:
      raise NotImplementedError("model configs should not exist in orbax")
    else:
      pipeline = self.load_diffusers_checkpoint()

    return pipeline
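
For orientation, a hypothetical usage sketch (not part of the commit; the argv-style pyconfig call and config path are assumptions). Because load_wan_configs_from_orbax returns None while no Orbax step has been saved, load_checkpoint falls through to load_diffusers_checkpoint and builds the pipeline from pretrained Diffusers weights:

# Hypothetical sketch: initialize config, then restore the pipeline.
from maxdiffusion import pyconfig
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT

pyconfig.initialize(["", "src/maxdiffusion/configs/base_wan_14b.yml"])  # argv-style init (assumed)
checkpointer = WanCheckpointer(pyconfig.config, WAN_CHECKPOINT)
pipeline = checkpointer.load_checkpoint()  # no saved step yet -> Diffusers weights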

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
@@ -220,6 +220,7 @@ slg_end: 1.0
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 30
+fps: 24
 save_final_checkpoint: False

 # SDXL Lightning parameters

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 0 deletions
@@ -524,6 +524,7 @@ class AttentionOp(nn.Module):
   quant: Quant = None

   def setup(self):
+    self.dpa_layer = None
     if self.attention_kernel == "cudnn_flash_te":
       from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error

src/maxdiffusion/pipelines/wan/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from .wan_pipeline import WanPipeline

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
100100
try:
101101
state[path].value = jax.device_put(val, sharding)
102102
except:
103-
breakpoint()
103+
raise ValueError("value should exist.")
104104
state = nnx.from_flat_state(state)
105105

106106
wan_transformer = nnx.merge(graphdef, state, rest_of_state)

src/maxdiffusion/train_wan.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Sequence

import jax
from absl import app
from maxdiffusion import max_logging, pyconfig
from maxdiffusion.train_utils import validate_train_config


def train(config):
  from maxdiffusion.trainers.wan_trainer import WanTrainer

  trainer = WanTrainer(config)
  trainer.start_training()


def main(argv: Sequence[str]) -> None:
  pyconfig.initialize(argv)
  config = pyconfig.config
  validate_train_config(config)
  max_logging.log(f"Found {jax.device_count()} devices.")
  train(config)


if __name__ == "__main__":
  app.run(main)
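
Assuming the argv convention pyconfig.initialize expects here (program name first, then a config path; a hypothetical invocation, not documented in this commit), the entry point would be launched along the lines of `python src/maxdiffusion/train_wan.py src/maxdiffusion/configs/base_wan_14b.yml`.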
src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import functools

import numpy as np
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from flax import nnx

from ..schedulers import FlaxEulerDiscreteScheduler
from .. import max_utils
from .. import max_logging
from ..checkpointing.wan_checkpointer import (
    WanCheckpointer,
    WAN_CHECKPOINT,
)
from multihost_dataloading import _form_global_array


class WanTrainer(WanCheckpointer):

  def __init__(self, config):
    WanCheckpointer.__init__(self, config, WAN_CHECKPOINT)
    if config.train_text_encoder:
      raise ValueError("this script currently doesn't support training text_encoders")

  def post_training_steps(self, pipeline, params, train_states, msg=""):
    pass

  def create_scheduler(self, pipeline, params):
    # TODO - set the right scheduler
    noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained(
        pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, subfolder="scheduler", dtype=jnp.float32
    )
    noise_scheduler_state = noise_scheduler.set_timesteps(
        state=noise_scheduler_state, num_inference_steps=self.config.num_inference_steps, timestep_spacing="flux"
    )
    return noise_scheduler, noise_scheduler_state

  def calculate_tflops(self, pipeline):
    pass

  def load_dataset(self, pipeline):
    # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
    #   Image pre-training: txt2img, 256px.
    #   Image-video joint training, stage 1: 256px images and 192px 5-sec videos at fps=16.
    #   Image-video joint training, stage 2: 480px images and 480px 5-sec videos at fps=16.
    #   Image-video joint training, final stage: 720px images and 720px 5-sec videos at fps=16.
    # Prompt embeds shape: (1, 512, 4096).
    # For now, we pass the same latents over and over.
    # TODO - create a dataset.
    global_batch_size = self.config.per_device_batch_size * jax.device_count()
    prompt_embeds = jax.random.normal(jax.random.key(self.config.seed), (global_batch_size, 512, 4096))
    latents = pipeline.prepare_latents(
        global_batch_size,
        vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
        vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
        height=self.config.height,
        width=self.config.width,
        num_frames=self.config.num_frames,
        num_channels_latents=pipeline.transformer.config.in_channels,
    )
    return (latents, prompt_embeds)

  def start_training(self):
    pipeline = self.load_checkpoint()
    mesh = pipeline.mesh

    optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, self.config.learning_rate)

    dummy_inputs = self.load_dataset(pipeline)
    # Shard each dummy input across the mesh as a global array.
    dummy_inputs = tuple(
        jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), x) for x in dummy_inputs
    )

    self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs)

  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data):
    graphdef, state = nnx.split((pipeline.transformer, optimizer))
    state = state.to_pure_dict()
    p_train_step = jax.jit(
        train_step,
        donate_argnums=(1,),
    )
    rng = jax.random.key(self.config.seed)
    start_step = 0
    for step in np.arange(start_step, self.config.max_train_steps):
      with pipeline.mesh:
        loss, state, rng = p_train_step(graphdef, state, data, rng)
      max_logging.log(f"loss: {loss}")


def train_step(graphdef, state, data, rng):
  return step_optimizer(graphdef, state, data, rng)


def step_optimizer(graphdef, state, data, rng):
  _, new_rng = jax.random.split(rng)

  def loss_fn(model):
    latents, prompt_embeds = data
    bsz = latents.shape[0]
    timesteps = jnp.array([0] * bsz, dtype=jnp.int32)

    noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)

    # TODO - add noise here

    model_pred = model(
        hidden_states=noise,
        timestep=timesteps,
        encoder_hidden_states=prompt_embeds,
        is_uncond=jnp.array(False, dtype=jnp.bool_),
        slg_mask=jnp.zeros(1, dtype=jnp.bool_),
    )
    target = noise - latents
    loss = (target - model_pred) ** 2
    loss = jnp.mean(loss)
    return loss

  model, optimizer = nnx.merge(graphdef, state)
  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  state = nnx.state((model, optimizer))
  state = state.to_pure_dict()
  return loss, state, new_rng
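
The loss above is a rectified-flow (flow-matching) objective: for an interpolant x_t = (1 - t) * latents + t * noise, the velocity target is noise - latents, which is exactly the `target` computed in loss_fn. The `TODO - add noise here` marks the missing interpolation step, since the model is currently fed pure noise at timestep 0. A minimal sketch of what that step could look like (the 1000-step normalization is an assumption, not a value from this commit):

import jax.numpy as jnp

def interpolate_latents(latents, noise, timesteps, num_train_timesteps=1000):
  # Rectified flow: x_t = (1 - t) * x_0 + t * eps, with t scaled into [0, 1].
  t = timesteps.astype(jnp.float32) / num_train_timesteps
  t = t.reshape(-1, *([1] * (latents.ndim - 1)))  # broadcast over non-batch dims
  return (1.0 - t) * latents + t * noise

# Inside loss_fn this would replace the pure-noise forward pass:
#   noisy_latents = interpolate_latents(latents, noise, timesteps)
#   model_pred = model(hidden_states=noisy_latents, ...)
#   target = noise - latents  # velocity target, as in the loss above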
