Skip to content

Commit 716598b

Browse files
wip - building pipeline and gen code.
1 parent 38bea20 commit 716598b

3 files changed

Lines changed: 254 additions & 1 deletion

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,9 @@ do_classifier_free_guidance: True
217217
guidance_scale: 3.5
218218
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
219219
guidance_rescale: 0.0
220-
num_inference_steps: 50
220+
num_inference_steps: 30
221221
save_final_checkpoint: False
222+
flow_shift: 5.0
222223

223224
# SDXL Lightning parameters
224225
lightning_from_pt: True

src/maxdiffusion/generate_wan.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Sequence
16+
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
17+
from maxdiffusion import pyconfig
18+
from absl import app
19+
20+
def run(config):
  """Build a `WanPipeline` from the maxdiffusion config and return it.

  Generation itself is still WIP; for now this only constructs the pipeline.

  Args:
    config: maxdiffusion `HyperParameters` produced by `pyconfig.initialize`.

  Returns:
    The constructed `WanPipeline`.
  """
  # Return the pipeline instead of dropping it in an unused local, so callers
  # and tests can use the constructed object.
  return WanPipeline.from_pretrained(config)
22+
23+
def main(argv: Sequence[str]) -> None:
  """absl entry point: initialize the maxdiffusion config from argv, then run."""
  pyconfig.initialize(argv)
  run(pyconfig.config)
26+
27+
if __name__ == "__main__":
  # Delegate flag parsing and program startup to absl.
  app.run(main)
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Union, Optional
16+
import numpy as np
17+
import jax
18+
from jax.sharding import Mesh, PositionalSharding
19+
from flax import nnx
20+
from ...pyconfig import HyperParameters
21+
from ... import max_utils
22+
from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae
23+
from ...models.wan.transformers.transformer_wan import WanModel
24+
from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache
25+
from maxdiffusion.video_processor import VideoProcessor
26+
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
27+
from transformers import AutoTokenizer, UMT5EncoderModel
28+
import ftfy
29+
import html
30+
import re
31+
import torch
32+
33+
def basic_clean(text):
  """Repair text encoding artifacts and unescape HTML entities.

  Runs ftfy mojibake fixing first, then unescapes HTML entities twice so
  double-encoded input (e.g. "&amp;amp;") also resolves, and trims the ends.
  """
  repaired = ftfy.fix_text(text)
  repaired = html.unescape(html.unescape(repaired))
  return repaired.strip()
37+
38+
39+
def whitespace_clean(text):
  """Collapse every run of whitespace into a single space and trim the ends."""
  return re.sub(r"\s+", " ", text).strip()
43+
44+
45+
def prompt_clean(text):
  """Normalize a prompt: fix encoding/HTML artifacts, then collapse whitespace."""
  return whitespace_clean(basic_clean(text))
48+
49+
class WanPipeline:
  r"""
  Pipeline for text-to-video generation using Wan.

  Args:
    tokenizer ([`AutoTokenizer`]):
      Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
      specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
    text_encoder ([`UMT5EncoderModel`]):
      [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
      the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
    transformer ([`WanModel`]):
      Conditional Transformer to denoise the input latents.
    vae ([`AutoencoderKLWan`]):
      Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
    vae_cache ([`AutoencoderKLWanCache`]):
      Cache object for the Wan VAE.
    scheduler ([`FlaxUniPCMultistepScheduler`]):
      A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    scheduler_state ([`UniPCMultistepSchedulerState`]):
      Initial state for `scheduler`.
    devices_array (`np.ndarray`):
      Device mesh layout produced by `max_utils.create_device_mesh`.
    mesh (`Mesh`):
      JAX device mesh used for sharding.
    config (`HyperParameters`):
      maxdiffusion run configuration.
  """

  def __init__(
      self,
      tokenizer: AutoTokenizer,
      text_encoder: UMT5EncoderModel,
      transformer: WanModel,
      vae: AutoencoderKLWan,
      vae_cache: AutoencoderKLWanCache,
      scheduler: FlaxUniPCMultistepScheduler,
      scheduler_state: UniPCMultistepSchedulerState,
      # np.array is a function, not a type; np.ndarray is the correct annotation.
      devices_array: np.ndarray,
      mesh: Mesh,
      config: HyperParameters,
  ):
    self.tokenizer = tokenizer
    self.text_encoder = text_encoder
    self.transformer = transformer
    self.vae = vae
    self.vae_cache = vae_cache
    self.scheduler = scheduler
    self.scheduler_state = scheduler_state
    self.devices_array = devices_array
    self.mesh = mesh
    self.config = config

    # NOTE(review): "temperal_downsample" is the (misspelled) attribute name the
    # Wan VAE exposes; keep it as-is for checkpoint/config compatibility.
    self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
    self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

  @classmethod
  def load_vae(cls, rngs: nnx.Rngs, config: HyperParameters):
    """Build the Wan VAE and its cache, loading pretrained weights onto CPU.

    Returns:
      Tuple of (vae, vae_cache).
    """
    wan_vae = AutoencoderKLWan.from_config(
        config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs
    )
    vae_cache = AutoencoderKLWanCache(wan_vae)

    graphdef, state = nnx.split(wan_vae, nnx.Param)
    params = state.to_pure_dict()
    # This replaces the randomly initialized params with the pretrained weights.
    params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
    params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
    wan_vae = nnx.merge(graphdef, params)

    return wan_vae, vae_cache

  @classmethod
  def load_text_encoder(cls, config: HyperParameters):
    """Load the pretrained UMT5 text encoder from the model repository."""
    text_encoder = UMT5EncoderModel.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="text_encoder",
    )
    return text_encoder

  @classmethod
  def load_tokenizer(cls, config: HyperParameters):
    """Load the tokenizer from the model repository."""
    tokenizer = AutoTokenizer.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="tokenizer",
    )
    return tokenizer

  @classmethod
  def load_transformer(cls, devices_array: np.ndarray, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
    """Build the Wan transformer, load pretrained weights, and replicate them across devices."""
    wan_transformer = WanModel.from_config(
        config.pretrained_model_name_or_path,
        subfolder="transformer",
        rngs=rngs,
        attention=config.attention,
        mesh=mesh,
        dtype=config.activations_dtype,
        weights_dtype=config.weights_dtype,
    )
    # Split params from the rest of the module state so only params are replaced.
    graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...)
    params = state.to_pure_dict()
    del state
    params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
    params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
    params = jax.device_put(params, PositionalSharding(devices_array).replicate())
    wan_transformer = nnx.merge(graphdef, params, rest_of_state)
    return wan_transformer

  @classmethod
  def load_scheduler(cls, config):
    """Load the UniPC multistep scheduler and its initial state."""
    scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="scheduler",
        flow_shift=config.flow_shift,  # 5.0 for 720p, 3.0 for 480p
    )
    return scheduler, scheduler_state

  @classmethod
  def from_pretrained(cls, config: HyperParameters):
    """Assemble a pipeline from pretrained components described by `config`."""
    devices_array = max_utils.create_device_mesh(config)
    mesh = Mesh(devices_array, config.mesh_axes)
    rng = jax.random.key(config.seed)
    rngs = nnx.Rngs(rng)

    wan_vae, vae_cache = cls.load_vae(rngs=rngs, config=config)
    text_encoder = cls.load_text_encoder(config=config)
    tokenizer = cls.load_tokenizer(config=config)
    transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
    scheduler, scheduler_state = cls.load_scheduler(config=config)

    # Use cls(...) rather than hard-coding WanPipeline(...) so subclasses
    # construct instances of themselves.
    return cls(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=wan_vae,
        vae_cache=vae_cache,
        scheduler=scheduler,
        scheduler_state=scheduler_state,
        devices_array=devices_array,
        mesh=mesh,
        config=config,
    )

  def _get_t5_prompt_embeds(
      self,
      prompt: Optional[Union[str, List[str]]] = None,
      num_videos_per_prompt: int = 1,
      max_sequence_length: int = 226,
  ):
    """Tokenize and encode `prompt` with the UMT5 text encoder.

    Encoder outputs at padding positions are discarded and replaced with zeros
    (re-padded to `max_sequence_length`), then embeddings are tiled
    `num_videos_per_prompt` times along the batch dimension.

    Returns:
      Tensor of shape (batch * num_videos_per_prompt, max_sequence_length, dim).
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt = [prompt_clean(u) for u in prompt]
    batch_size = len(prompt)

    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
    seq_lens = mask.gt(0).sum(dim=1).long()
    prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
    # Keep only each prompt's true-length prefix, then zero-pad back to
    # max_sequence_length so pad positions carry zeros, not encoder outputs.
    prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
    )

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds

  def encode_prompt(
      self,
      prompt: Union[str, List[str]],
      negative_prompt: Optional[Union[str, List[str]]] = None,
      num_videos_per_prompt: int = 1,
      max_sequence_length: int = 226,
  ):
    """Encode a prompt (and optional negative prompt) into T5 embeddings.

    The original commit left this method as an empty stub (a syntax error);
    this completes it in the standard diffusers WanPipeline fashion by
    delegating to `_get_t5_prompt_embeds`.

    Returns:
      Tuple of (prompt_embeds, negative_prompt_embeds); the second element is
      None when `negative_prompt` is not provided.
    """
    prompt_embeds = self._get_t5_prompt_embeds(
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
    )

    negative_prompt_embeds = None
    if negative_prompt is not None:
      negative_prompt_embeds = self._get_t5_prompt_embeds(
          prompt=negative_prompt,
          num_videos_per_prompt=num_videos_per_prompt,
          max_sequence_length=max_sequence_length,
      )

    return prompt_embeds, negative_prompt_embeds

0 commit comments

Comments
 (0)