# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import html
import re
from typing import List, Optional, Union

import ftfy
import jax
import numpy as np
import torch
from flax import nnx
from jax.sharding import Mesh, PositionalSharding
from transformers import AutoTokenizer, UMT5EncoderModel

from maxdiffusion.video_processor import VideoProcessor

from ... import max_utils
from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache
from ...models.wan.transformers.transformer_wan import WanModel
from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae
from ...pyconfig import HyperParameters
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState

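# Prompt-cleaning helpers, as in the upstream diffusers Wan pipeline: fix
# mojibake with ftfy, unescape HTML entities, and collapse whitespace runs.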
def basic_clean(text):
  text = ftfy.fix_text(text)
  text = html.unescape(html.unescape(text))
  return text.strip()


def whitespace_clean(text):
  text = re.sub(r"\s+", " ", text)
  text = text.strip()
  return text


def prompt_clean(text):
  text = whitespace_clean(basic_clean(text))
  return text

class WanPipeline:
  r"""
  Pipeline for text-to-video generation using Wan.

  Args:
      tokenizer ([`T5Tokenizer`]):
          Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
          specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
      text_encoder ([`T5EncoderModel`]):
          [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
          the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
      transformer ([`WanModel`]):
          Conditional Transformer to denoise the input latents.
      vae ([`AutoencoderKLWan`]):
          Variational Auto-Encoder (VAE) model to encode and decode videos to and from latent representations.
      vae_cache ([`AutoencoderKLWanCache`]):
          Cache object for `vae`.
      scheduler ([`FlaxUniPCMultistepScheduler`]):
          A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
      scheduler_state ([`UniPCMultistepSchedulerState`]):
          State object for `scheduler`.
      devices_array (`np.ndarray`):
          Array of JAX devices used to build the device mesh.
      mesh ([`Mesh`]):
          JAX device mesh used for sharding.
      config ([`HyperParameters`]):
          maxdiffusion configuration object.
  """
  def __init__(
      self,
      tokenizer: AutoTokenizer,
      text_encoder: UMT5EncoderModel,
      transformer: WanModel,
      vae: AutoencoderKLWan,
      vae_cache: AutoencoderKLWanCache,
      scheduler: FlaxUniPCMultistepScheduler,
      scheduler_state: UniPCMultistepSchedulerState,
      devices_array: np.ndarray,
      mesh: Mesh,
      config: HyperParameters,
  ):
    self.tokenizer = tokenizer
    self.text_encoder = text_encoder
    self.transformer = transformer
    self.vae = vae
    self.vae_cache = vae_cache
    self.scheduler = scheduler
    self.scheduler_state = scheduler_state
    self.devices_array = devices_array
    self.mesh = mesh
    self.config = config

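    # VAE compression factors: 2x per enabled temporal downsample stage, and
    # 2x per downsampling stage spatially. "temperal_downsample" (sic) is the
    # attribute name used by the upstream AutoencoderKLWan config.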
    self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
    self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

  @classmethod
  def load_vae(cls, rngs: nnx.Rngs, config: HyperParameters):
    wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs)
    vae_cache = AutoencoderKLWanCache(wan_vae)

    graphdef, state = nnx.split(wan_vae, nnx.Param)
    params = state.to_pure_dict()
    # Replace the randomly initialized params with the pretrained VAE weights.
    params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
    params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
    wan_vae = nnx.merge(graphdef, params)

    return wan_vae, vae_cache

  @classmethod
  def load_text_encoder(cls, config: HyperParameters):
    text_encoder = UMT5EncoderModel.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="text_encoder",
    )
    return text_encoder

  @classmethod
  def load_tokenizer(cls, config: HyperParameters):
    tokenizer = AutoTokenizer.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="tokenizer",
    )
    return tokenizer

  @classmethod
  def load_transformer(cls, devices_array: np.ndarray, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
    wan_transformer = WanModel.from_config(
        config.pretrained_model_name_or_path,
        subfolder="transformer",
        rngs=rngs,
        attention=config.attention,
        mesh=mesh,
        dtype=config.activations_dtype,
        weights_dtype=config.weights_dtype,
    )
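    # Split the module into its graph definition, trainable params, and the
    # remaining state, so the randomly initialized params can be swapped for
    # pretrained weights before merging back.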
    graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...)
    params = state.to_pure_dict()
    del state
    params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu")
    params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
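    # Replicate the pretrained weights across all devices in `devices_array`.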
    params = jax.device_put(params, PositionalSharding(devices_array).replicate())
    wan_transformer = nnx.merge(graphdef, params, rest_of_state)
    return wan_transformer

  @classmethod
  def load_scheduler(cls, config: HyperParameters):
    scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="scheduler",
        flow_shift=config.flow_shift,  # 5.0 for 720p, 3.0 for 480p
    )
    return scheduler, scheduler_state

  @classmethod
  def from_pretrained(cls, config: HyperParameters):
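    """Builds the full pipeline from `config.pretrained_model_name_or_path`:
    device mesh, VAE (with cache), text encoder, tokenizer, transformer, and
    scheduler."""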
    devices_array = max_utils.create_device_mesh(config)
    mesh = Mesh(devices_array, config.mesh_axes)
    rng = jax.random.key(config.seed)
    rngs = nnx.Rngs(rng)

    wan_vae, vae_cache = cls.load_vae(rngs=rngs, config=config)
    text_encoder = cls.load_text_encoder(config=config)
    tokenizer = cls.load_tokenizer(config=config)
    transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
    scheduler, scheduler_state = cls.load_scheduler(config=config)

    return cls(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=wan_vae,
        vae_cache=vae_cache,
        scheduler=scheduler,
        scheduler_state=scheduler_state,
        devices_array=devices_array,
        mesh=mesh,
        config=config,
    )

  def _get_t5_prompt_embeds(
      self,
      prompt: Optional[Union[str, List[str]]] = None,
      num_videos_per_prompt: int = 1,
      max_sequence_length: int = 226,
  ):
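    """Tokenizes `prompt`, runs the UMT5 text encoder, zero-pads each
    embedding to `max_sequence_length`, and repeats the result
    `num_videos_per_prompt` times."""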
    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt = [prompt_clean(u) for u in prompt]
    batch_size = len(prompt)

    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
    seq_lens = mask.gt(0).sum(dim=1).long()
    prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
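    # Trim each embedding to its true (unpadded) length, then zero-pad back to
    # max_sequence_length so padding positions hold zeros rather than the
    # encoder's pad-token outputs.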
    prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
    )

    # Duplicate text embeddings for each generation per prompt, using an MPS-friendly method.
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds

  def encode_prompt(
      self,
      prompt: Union[str, List[str]],
      negative_prompt: Optional[Union[str, List[str]]] = None,
      num_videos_per_prompt: int = 1,
      max_sequence_length: int = 226,
  ):
    # A minimal sketch of the standard Wan prompt-encoding flow: embed the
    # prompt and, when classifier-free guidance is used, the negative prompt,
    # both via `_get_t5_prompt_embeds`.
    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt_embeds = self._get_t5_prompt_embeds(
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
    )

    negative_prompt_embeds = None
    if negative_prompt is not None:
      negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
      negative_prompt_embeds = self._get_t5_prompt_embeds(
          prompt=negative_prompt,
          num_videos_per_prompt=num_videos_per_prompt,
          max_sequence_length=max_sequence_length,
      )

    return prompt_embeds, negative_prompt_embeds
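
# Usage sketch (illustrative): constructing the pipeline and encoding a
# prompt. Assumes `config` is a maxdiffusion HyperParameters object loaded
# elsewhere; the exact config entry point may differ.
#
#   pipeline = WanPipeline.from_pretrained(config)
#   prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
#       prompt="A cat surfing a wave at sunset",
#       negative_prompt="blurry, low quality",
#   )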