1+ """
2+ Copyright 2025 Google LLC
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ https://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ """

import functools

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from flax import nnx

from ..schedulers import FlaxEulerDiscreteScheduler
from .. import max_utils
from .. import max_logging
from ..checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT
from multihost_dataloading import _form_global_array


class WanTrainer(WanCheckpointer):

  def __init__(self, config):
    super().__init__(config, WAN_CHECKPOINT)
    if config.train_text_encoder:
      raise ValueError("This script currently does not support training text encoders.")

  def post_training_steps(self, pipeline, params, train_states, msg=""):
    pass

  def create_scheduler(self, pipeline, params):
    # TODO - set right scheduler
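    # Note: Wan 2.1 is trained with a flow-matching objective (see the paper
    # linked in load_dataset below), so the Euler-discrete scheduler here is a
    # stand-in until the proper flow-match scheduler is wired in.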
    noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained(
        pretrained_model_name_or_path=self.config.pretrained_model_name_or_path, subfolder="scheduler", dtype=jnp.float32
    )
    noise_scheduler_state = noise_scheduler.set_timesteps(
        state=noise_scheduler_state, num_inference_steps=self.config.num_inference_steps, timestep_spacing="flux"
    )
    return noise_scheduler, noise_scheduler_state

  def calculate_tflops(self, pipeline):
    pass

  def load_dataset(self, pipeline):
    # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
    #   Image pre-training: txt2img 256px
    #   Image-video joint training, stage 1: 256px images and 192px 5 sec videos at fps=16
    #   Image-video joint training, stage 2: 480px images and 480px 5 sec videos at fps=16
    #   Image-video joint training, final stage: 720px images and 720px 5 sec videos at fps=16
    # prompt embeds shape: (global_batch_size, 512, 4096)
    # For now, we pass the same latents over and over.
    # TODO - create a dataset
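    # A real implementation would return an iterator over pre-encoded
    # (latents, prompt_embeds) batches; the dummy tensors below only match the
    # shapes the model expects.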
    global_batch_size = self.config.per_device_batch_size * jax.device_count()
    prompt_embeds = jax.random.normal(jax.random.key(self.config.seed), (global_batch_size, 512, 4096))
    latents = pipeline.prepare_latents(
        global_batch_size,
        vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
        vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
        height=self.config.height,
        width=self.config.width,
        num_frames=self.config.num_frames,
        num_channels_latents=pipeline.transformer.config.in_channels,
    )
    return (latents, prompt_embeds)

  def start_training(self):
    pipeline = self.load_checkpoint()
    mesh = pipeline.mesh

    optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, self.config.learning_rate)

    dummy_inputs = self.load_dataset(pipeline)
    dummy_inputs = tuple(
        jtu.tree_map_with_path(functools.partial(_form_global_array, global_mesh=mesh), x) for x in dummy_inputs
    )

    self.training_loop(pipeline, optimizer, learning_rate_scheduler, dummy_inputs)

  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data):
    graphdef, state = nnx.split((pipeline.transformer, optimizer))
    state = state.to_pure_dict()
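    # Donate the state pytree (positional argument 1) so XLA can reuse its
    # buffers for the updated state instead of copying them every step.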
    p_train_step = jax.jit(
        train_step,
        donate_argnums=(1,),
    )
    rng = jax.random.key(self.config.seed)
    start_step = 0
    for step in np.arange(start_step, self.config.max_train_steps):
      with pipeline.mesh:
        loss, state, rng = p_train_step(graphdef, state, data, rng)
      max_logging.log(f"step: {step}, loss: {loss}")

def train_step(graphdef, state, data, rng):
  return step_optimizer(graphdef, state, data, rng)

def step_optimizer(graphdef, state, data, rng):
  _, new_rng = jax.random.split(rng)

  def loss_fn(model):
    latents, prompt_embeds = data
    bsz = latents.shape[0]
    timesteps = jnp.array([0] * bsz, dtype=jnp.int32)
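    # All-zero timesteps are a placeholder; real training would sample one per
    # example, e.g. jax.random.randint(new_rng, (bsz,), 0, num_train_timesteps).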

    noise = jax.random.normal(
        key=new_rng,
        shape=latents.shape,
        dtype=latents.dtype,
    )

    # TODO - add noise here
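    # A flow-matching forward process would interpolate here, e.g.
    #   sigma = timesteps / num_train_timesteps  (broadcast to the latent shape)
    #   noisy_latents = (1.0 - sigma) * latents + sigma * noise
    # and feed noisy_latents to the model; for now the model sees pure noise.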

    model_pred = model(
        hidden_states=noise,
        timestep=timesteps,
        encoder_hidden_states=prompt_embeds,
        is_uncond=jnp.array(False, dtype=jnp.bool_),
        slg_mask=jnp.zeros(1, dtype=jnp.bool_),
    )
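    # Flow-matching velocity target: the model regresses v = noise - x0.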
    target = noise - latents
    loss = (target - model_pred) ** 2
    loss = jnp.mean(loss)
    return loss

  model, optimizer = nnx.merge(graphdef, state)
  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  state = nnx.state((model, optimizer))
  state = state.to_pure_dict()
  return loss, state, new_rng
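
# A minimal usage sketch, assuming a config object exposing the attributes read
# above (pretrained_model_name_or_path, per_device_batch_size, seed, height,
# width, num_frames, num_inference_steps, learning_rate, max_train_steps,
# train_text_encoder); the `load_config` helper below is hypothetical.
#
#   config = load_config("configs/wan_2_1.yml")
#   trainer = WanTrainer(config)
#   trainer.start_training()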