1+ """
2+ Copyright 2025 Google LLC
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ https://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ """
16+
17+ from typing import Union , List
18+ from transformers import AutoTokenizer , UMT5EncoderModel
19+ import torch
20+ from ...models .wan .transformers .transformer_flux_wan_nnx import WanModel
21+ from ...models .wan .autoencoder_kl_wan import AutoencoderKLWan
22+ from ..pipeline_flax_utils import FlaxDiffusionPipeline
23+ from ...video_processor import VideoProcessor
24+ from ...schedulers import FlowMatchEulerDiscreteScheduler
25+
class WanPipeline(FlaxDiffusionPipeline):
    """Pipeline for text-to-video generation using the Wan transformer.

    Wires together the UMT5 text encoder, the Wan diffusion transformer,
    the Wan VAE, and a flow-match Euler scheduler, and exposes the prompt
    embedding helper used by the sampling loop.
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: UMT5EncoderModel,
        transformer: WanModel,
        vae: AutoencoderKLWan,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        """Register the pipeline components and derive VAE scale factors.

        Args:
            tokenizer: Tokenizer matching `text_encoder`.
            text_encoder: UMT5 encoder that produces prompt embeddings.
            transformer: Wan denoising transformer.
            vae: Wan video VAE used to decode latents.
            scheduler: Flow-match Euler discrete scheduler.
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )

        # `temperal_downsample` ([sic] — attribute name comes from the VAE
        # config) is a per-stage list of booleans; the temporal factor is
        # 2^(number of stages that downsample in time), the spatial factor
        # is 2^(number of stages). Fallbacks (4 / 8) match the Wan defaults
        # when no VAE is registered.
        self.vae_scale_factor_temporal = (
            2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
        )
        self.vae_scale_factor_spatial = (
            2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
        )
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    def _get_t5_prompt_embds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
    ):
        """Encode `prompt` with the UMT5 text encoder.

        Args:
            prompt: A prompt string or list of prompt strings.
            num_videos_per_prompt: Number of videos generated per prompt;
                embeddings are duplicated accordingly.
            max_sequence_length: Token length to pad/truncate each prompt to.

        Returns:
            Tensor of shape
            `(batch_size * num_videos_per_prompt, max_sequence_length, hidden_dim)`.

        Raises:
            ValueError: If `prompt` is None.
        """
        if prompt is None:
            # Fail loudly instead of the opaque TypeError len(None) raises.
            raise ValueError("`prompt` must be a string or a list of strings.")
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
        # Number of real (non-padding) tokens per prompt.
        seq_lens = mask.gt(0).sum(dim=1).long()

        prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
        # Keep only the valid-token embeddings, then re-pad with zeros so
        # padding positions carry exact zeros rather than encoder outputs.
        prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
        prompt_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds],
            dim=0,
        )

        # Duplicate text embeddings for each generation per prompt, using an
        # mps-friendly method (repeat + view rather than repeat_interleave).
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds