1313 See the License for the specific language governing permissions and
1414 limitations under the License.
1515"""
16+
1617import html
1718from typing import Callable , List , Union , Sequence , Optional
1819import time
3738 setup_initial_state ,
3839)
3940
41+
def basic_clean(text):
    """Repair mojibake and HTML-escaped artifacts in raw prompt text.

    Runs ftfy to fix encoding damage, unescapes HTML entities twice
    (scraped text is often double-escaped, e.g. "&amp;amp;"), and trims
    surrounding whitespace.
    """
    fixed = ftfy.fix_text(text)
    unescaped = html.unescape(html.unescape(fixed))
    return unescaped.strip()
4446
4547
def whitespace_clean(text):
    """Collapse every run of whitespace to a single space and trim the ends."""
    return re.sub(r"\s+", " ", text).strip()
5052
5153
def prompt_clean(text):
    """Normalize a user prompt: fix encoding/HTML artifacts, then collapse whitespace."""
    return whitespace_clean(basic_clean(text))
57+
5558
def _get_t5_prompt_embeds(
    tokenizer: AutoTokenizer,
    text_encoder: UMT5EncoderModel,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 512,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    # NOTE(review): the middle of this signature was elided in the diff view;
    # parameter names/order are recovered from the keyword call sites in
    # encode_prompt. Defaults (especially max_sequence_length) — TODO confirm
    # against the original file.
    """Tokenize and encode prompts with the UMT5 text encoder.

    Each prompt is cleaned, tokenized to `max_sequence_length` with padding and
    truncation, encoded, cropped to its true (unpadded) length, then re-padded
    with zeros so padded positions carry zero embeddings rather than the
    encoder's pad-token activations.

    Args:
        tokenizer: HuggingFace tokenizer producing `input_ids`/`attention_mask`.
        text_encoder: UMT5 encoder; its `last_hidden_state` is used.
        prompt: A single prompt or a list of prompts.
        num_videos_per_prompt: How many videos to generate per prompt; the
            embeddings are duplicated accordingly.
        max_sequence_length: Fixed token length for padding/truncation.
        device: Device to run the encoder on and place results.
        dtype: Target dtype for the returned embeddings.

    Returns:
        Tensor of shape (batch_size * num_videos_per_prompt, max_sequence_length, hidden_dim).
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt = [prompt_clean(u) for u in prompt]
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
    # True token count per prompt (non-zero attention-mask positions).
    seq_lens = mask.gt(0).sum(dim=1).long()

    prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    # Crop each sequence to its real length, then zero-pad back to the fixed
    # length so padding positions are exactly zero.
    prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
    )

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds
9599
def encode_prompt(
    tokenizer: AutoTokenizer,
    text_encoder: UMT5EncoderModel,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 512,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    # NOTE(review): part of this signature was elided in the diff view; the
    # parameter set is recovered from the docstring, the body, and the call
    # site in run(). Defaults — TODO confirm against the original file.
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether to use classifier free guidance or not.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        device: (`torch.device`, *optional*):
            torch device
        dtype: (`torch.dtype`, *optional*):
            torch dtype
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    if prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds = _get_t5_prompt_embeds(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    if do_classifier_free_guidance and negative_prompt_embeds is None:
        # An empty string is a valid "no negative prompt" default for CFG.
        negative_prompt = negative_prompt or ""
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        negative_prompt_embeds = _get_t5_prompt_embeds(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            prompt=negative_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    return prompt_embeds, negative_prompt_embeds
183+
180184
181185def run (config ):
182186 max_logging .log ("Wan 2.1 inference script" )
@@ -188,17 +192,15 @@ def run(config):
188192 global_batch_size = config .per_device_batch_size * jax .local_device_count ()
189193
190194 tokenizer = AutoTokenizer .from_pretrained (
191- config .pretrained_model_name_or_path , subfolder = "tokenizer" , dtype = config .weights_dtype
195+ config .pretrained_model_name_or_path , subfolder = "tokenizer" , dtype = config .weights_dtype
192196 )
193197 text_encoder = UMT5EncoderModel .from_pretrained (
194- config .pretrained_model_name_or_path , subfolder = "text_encoder" ,
198+ config .pretrained_model_name_or_path ,
199+ subfolder = "text_encoder" ,
195200 )
196201 s0 = time .perf_counter ()
197202 prompt_embeds , negative_prompt_embeds = encode_prompt (
198- tokenizer = tokenizer ,
199- text_encoder = text_encoder ,
200- prompt = config .prompt ,
201- negative_prompt = config .negative_prompt
203+ tokenizer = tokenizer , text_encoder = text_encoder , prompt = config .prompt , negative_prompt = config .negative_prompt
202204 )
203205 max_logging .log (f"text encoding time: { (time .perf_counter () - s0 )} " )
204206
@@ -209,20 +211,15 @@ def run(config):
209211 # )
210212 # breakpoint()
211213
212- pipeline , params = WanPipeline .from_pretrained (
213- config .pretrained_model_name_or_path ,
214- vae = None ,
215- transformer = None
216- )
217-
218- #wan_transformer = WanModel(rngs=nnx.Rngs(config.seed))
219-
214+ pipeline , params = WanPipeline .from_pretrained (config .pretrained_model_name_or_path , vae = None , transformer = None )
220215
216+ # wan_transformer = WanModel(rngs=nnx.Rngs(config.seed))
221217
222218
def main(argv: Sequence[str]) -> None:
    """CLI entry point: parse config from argv, then run inference."""
    pyconfig.initialize(argv)
    run(pyconfig.config)
226222
223+
if __name__ == "__main__":
    app.run(main)
0 commit comments