-
Notifications
You must be signed in to change notification settings - Fork 69
Expand file tree
/
Copy pathwan_pipeline.py
More file actions
651 lines (569 loc) · 24.3 KB
/
wan_pipeline.py
File metadata and controls
651 lines (569 loc) · 24.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union, Optional
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import flax
import flax.linen as nn
from flax import nnx
from flax.linen import partitioning as nn_partitioning
from ...pyconfig import HyperParameters
from ... import max_logging
from ... import max_utils
from ...max_utils import get_flash_block_sizes, get_precision, device_put_replicated
from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae
from ...models.wan.transformers.transformer_wan import WanModel
from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache
from maxdiffusion.video_processor import VideoProcessor
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
from transformers import AutoTokenizer, UMT5EncoderModel
from maxdiffusion.utils.import_utils import is_ftfy_available
from maxdiffusion.maxdiffusion_utils import get_dummy_wan_inputs
import html
import re
import torch
import qwix
def basic_clean(text):
  """Repair and unescape ``text``, then trim surrounding whitespace.

  When ``is_ftfy_available()`` reports that ``ftfy`` is installed, the text is
  first passed through ``ftfy.fix_text``. HTML entities are then unescaped
  twice, so doubly-escaped input (e.g. ``&amp;amp;``) is fully decoded.

  Args:
    text: The string to clean.

  Returns:
    The cleaned string with leading and trailing whitespace removed.
  """
  if is_ftfy_available():
    # Imported lazily because ftfy is an optional dependency.
    import ftfy

    text = ftfy.fix_text(text)
  unescaped_once = html.unescape(text)
  unescaped_twice = html.unescape(unescaped_once)
  return unescaped_twice.strip()
def whitespace_clean(text):
  """Collapse each run of whitespace in ``text`` to one space and trim both ends.

  Args:
    text: The string to normalise.

  Returns:
    The normalised string.
  """
  return re.sub(r"\s+", " ", text).strip()
def prompt_clean(text):
  """Clean a prompt by applying ``basic_clean`` followed by ``whitespace_clean``.

  Args:
    text: The raw prompt string.

  Returns:
    The cleaned prompt string.
  """
  repaired = basic_clean(text)
  return whitespace_clean(repaired)
def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState:
  """Store ``logical_axis_rules`` on ``vs`` as ``sharding_rules``.

  The variable state is modified in place.

  Args:
    vs: The variable state to annotate.
    logical_axis_rules: The rules to attach.

  Returns:
    The same ``vs`` object, so the function can be used as a tree-map callback.
  """
  vs.sharding_rules = logical_axis_rules
  return vs
# For some reason, jitting this function increases the memory significantly, so instead manually move weights to device.
def create_sharded_logical_transformer(
    devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None
):
  """Build a ``WanModel`` and place its parameters on device one weight at a time.

  The model is first traced with ``nnx.eval_shape``, its logical partition specs
  are mapped onto ``mesh``, and every loaded weight is then placed with the
  sharding computed for its path.

  Args:
    devices_array: Device array associated with ``mesh``; not read by this function.
    mesh: Mesh used to resolve parameter shardings; also stored in the model config.
    rngs: RNG streams forwarded to the ``WanModel`` constructor.
    config: Supplies model paths, dtypes, attention settings and ``logical_axis_rules``.
    restored_checkpoint: Optional mapping with ``"wan_config"`` and ``"wan_state"``
      entries. When truthy, the config and weights are taken from it instead of
      being loaded from the pretrained paths in ``config``.

  Returns:
    The ``WanModel`` rebuilt from the graph definition and the placed parameters.
  """

  def build_model(rngs: nnx.Rngs, wan_config: dict):
    return WanModel(**wan_config, rngs=rngs)

  use_checkpoint = bool(restored_checkpoint)

  # 1. Load config.
  if use_checkpoint:
    wan_config = restored_checkpoint["wan_config"]
  else:
    wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")

  # Runtime settings from `config` are written over whatever the loaded model config holds.
  runtime_overrides = {
      "mesh": mesh,
      "dtype": config.activations_dtype,
      "weights_dtype": config.weights_dtype,
      "attention": config.attention,
      "precision": get_precision(config),
      "flash_block_sizes": get_flash_block_sizes(config),
      "remat_policy": config.remat_policy,
      "names_which_can_be_saved": config.names_which_can_be_saved,
      "names_which_can_be_offloaded": config.names_which_can_be_offloaded,
      "flash_min_seq_length": config.flash_min_seq_length,
      "dropout": config.dropout,
  }
  for option, value in runtime_overrides.items():
    wan_config[option] = value

  # 2. eval_shape - will not use flops or create weights on device
  # thus not using HBM memory.
  abstract_model = nnx.eval_shape(partial(build_model, wan_config=wan_config), rngs=rngs)
  graphdef, param_state, other_state = nnx.split(abstract_model, nnx.Param, ...)

  # 3. retrieve the state shardings, mapping logical names to mesh axis names.
  partition_specs = nnx.get_partition_spec(param_state)
  mesh_shardings = nn.logical_to_mesh_sharding(partition_specs, mesh, config.logical_axis_rules)
  flat_shardings = dict(nnx.to_flat_state(mesh_shardings))
  params = param_state.to_pure_dict()
  flat_state = dict(nnx.to_flat_state(param_state))

  # 4. Load pretrained weights and move them to device using the state shardings from (3) above.
  # This helps with loading sharded weights directly into the accelerators without first copying them
  # all to one device and then distributing them, thus using low HBM memory.
  if use_checkpoint:
    params = restored_checkpoint["wan_state"]
  else:
    params = load_wan_transformer(
        config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
    )
  params = jax.tree_util.tree_map(lambda weight: weight.astype(config.weights_dtype), params)
  for path, weight in flax.traverse_util.flatten_dict(params).items():
    # Paths from a restored checkpoint have one extra trailing key that the flat
    # state does not use - presumably the variable's value field; TODO confirm.
    key = path[:-1] if use_checkpoint else path
    flat_state[key].value = device_put_replicated(weight, flat_shardings[key].value)

  return nnx.merge(graphdef, nnx.from_flat_state(flat_state), other_state)
@nnx.jit(static_argnums=(1,), donate_argnums=(0,))
def create_sharded_logical_model(model, logical_axis_rules):
  """Return ``model`` with its parameters constrained to their partition specs.

  Every ``nnx.Param`` variable state is tagged with ``logical_axis_rules``, the
  resulting partition specs are read back, and a sharding constraint built from
  them is applied to the parameter state before the model is re-merged.

  Args:
    model: The NNX model to shard. It is donated to the jitted computation.
    logical_axis_rules: Rules attached to each variable state; treated as static by the jit.

  Returns:
    The model rebuilt from the constrained parameter state.
  """
  graphdef, param_state, other_state = nnx.split(model, nnx.Param, ...)
  attach_rules = partial(_add_sharding_rule, logical_axis_rules=logical_axis_rules)
  param_state = jax.tree.map(attach_rules, param_state, is_leaf=lambda leaf: isinstance(leaf, nnx.VariableState))
  constrained_state = jax.lax.with_sharding_constraint(param_state, nnx.get_partition_spec(param_state))
  return nnx.merge(graphdef, constrained_state, other_state)
class WanPipeline:
r"""
Pipeline for text-to-video generation using Wan.
tokenizer ([`T5Tokenizer`]):
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanModel`]):
Conditional Transformer to denoise the input latents.
scheduler ([`FlaxUniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
transformer: WanModel,
vae: AutoencoderKLWan,
vae_cache: AutoencoderKLWanCache,
scheduler: FlaxUniPCMultistepScheduler,
scheduler_state: UniPCMultistepSchedulerState,
devices_array: np.array,
mesh: Mesh,
config: HyperParameters,
):
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.transformer = transformer
self.vae = vae
self.vae_cache = vae_cache
self.scheduler = scheduler
self.scheduler_state = scheduler_state
self.devices_array = devices_array
self.mesh = mesh
self.config = config
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.p_run_inference = None
@classmethod
def load_text_encoder(cls, config: HyperParameters):
text_encoder = UMT5EncoderModel.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="text_encoder",
)
return text_encoder
@classmethod
def load_tokenizer(cls, config: HyperParameters):
tokenizer = AutoTokenizer.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="tokenizer",
)
return tokenizer
@classmethod
def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
def create_model(rngs: nnx.Rngs, config: HyperParameters):
wan_vae = AutoencoderKLWan.from_config(
config.pretrained_model_name_or_path,
subfolder="vae",
rngs=rngs,
mesh=mesh,
dtype=config.activations_dtype,
weights_dtype=config.weights_dtype,
)
return wan_vae
# 1. eval shape
p_model_factory = partial(create_model, config=config)
wan_vae = nnx.eval_shape(p_model_factory, rngs=rngs)
graphdef, state = nnx.split(wan_vae, nnx.Param)
# 2. retrieve the state shardings, mapping logical names to mesh axis names.
logical_state_spec = nnx.get_partition_spec(state)
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
params = state.to_pure_dict()
state = dict(nnx.to_flat_state(state))
# 4. Load pretrained weights and move them to device using the state shardings from (3) above.
# This helps with loading sharded weights directly into the accelerators without fist copying them
# all to one device and then distributing them, thus using low HBM memory.
params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
for path, val in flax.traverse_util.flatten_dict(params).items():
sharding = logical_state_sharding[path].value
if config.replicate_vae:
sharding = NamedSharding(mesh, P())
state[path].value = device_put_replicated(val, sharding)
state = nnx.from_flat_state(state)
wan_vae = nnx.merge(graphdef, state)
vae_cache = AutoencoderKLWanCache(wan_vae)
return wan_vae, vae_cache
@classmethod
def get_basic_config(cls, dtype):
rules = [
qwix.QtRule(
module_path=".*", # Apply to all modules
weight_qtype=dtype,
act_qtype=dtype,
)
]
return rules
@classmethod
def get_fp8_config(cls, quantization_calibration_method: str):
"""
fp8 config rules with per-tensor calibration.
FLAX API (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api):
The autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, which is the recommended practice.
"""
rules = [
qwix.QtRule(
module_path=".*", # Apply to all modules
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e4m3fn,
bwd_use_original_residuals=True,
disable_channelwise_axes=True, # per_tensor calibration
weight_calibration_method=quantization_calibration_method,
act_calibration_method=quantization_calibration_method,
bwd_calibration_method=quantization_calibration_method,
)
]
return rules
@classmethod
def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]:
"""Get quantization rules based on the config."""
if not getattr(config, "use_qwix_quantization", False):
return None
quantization_calibration_method = getattr(config, "quantization_calibration_method", "absmax")
match config.quantization:
case "int8":
return qwix.QtProvider(cls.get_basic_config(jnp.int8))
case "fp8":
return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn))
case "fp8_full":
return qwix.QtProvider(cls.get_fp8_config(quantization_calibration_method))
return None
@classmethod
def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline: "WanPipeline", mesh: Mesh):
"""Quantizes the transformer model."""
q_rules = cls.get_qt_provider(config)
if not q_rules:
return model
max_logging.log("Quantizing transformer with Qwix.")
batch_size = jnp.ceil(config.per_device_batch_size * jax.local_device_count()).astype(jnp.int32)
latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
model_inputs = (latents, timesteps, prompt_embeds)
with mesh:
quantized_model = qwix.quantize_model(model, q_rules, *model_inputs)
max_logging.log("Qwix Quantization complete.")
return quantized_model
@classmethod
def load_transformer(
cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None
):
with mesh:
wan_transformer = create_sharded_logical_transformer(
devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint
)
return wan_transformer
@classmethod
def load_scheduler(cls, config):
scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained(
config.pretrained_model_name_or_path,
subfolder="scheduler",
flow_shift=config.flow_shift, # 5.0 for 720p, 3.0 for 480p
)
return scheduler, scheduler_state
@classmethod
def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True):
devices_array = max_utils.create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)
rng = jax.random.key(config.seed)
rngs = nnx.Rngs(rng)
transformer = None
tokenizer = None
scheduler = None
scheduler_state = None
text_encoder = None
if not vae_only:
if load_transformer:
with mesh:
transformer = cls.load_transformer(
devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint
)
text_encoder = cls.load_text_encoder(config=config)
tokenizer = cls.load_tokenizer(config=config)
scheduler, scheduler_state = cls.load_scheduler(config=config)
with mesh:
wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
return WanPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=wan_vae,
vae_cache=vae_cache,
scheduler=scheduler,
scheduler_state=scheduler_state,
devices_array=devices_array,
mesh=mesh,
config=config,
)
@classmethod
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
devices_array = max_utils.create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)
rng = jax.random.key(config.seed)
rngs = nnx.Rngs(rng)
transformer = None
tokenizer = None
scheduler = None
scheduler_state = None
text_encoder = None
if not vae_only:
if load_transformer:
with mesh:
transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
text_encoder = cls.load_text_encoder(config=config)
tokenizer = cls.load_tokenizer(config=config)
scheduler, scheduler_state = cls.load_scheduler(config=config)
with mesh:
wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
pipeline = WanPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=wan_vae,
vae_cache=vae_cache,
scheduler=scheduler,
scheduler_state=scheduler_state,
devices_array=devices_array,
mesh=mesh,
config=config,
)
pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, mesh)
return pipeline
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(u) for u in prompt]
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
prompt_embeds: jax.Array = None,
negative_prompt_embeds: jax.Array = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
)
prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype)
if negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
)
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self,
batch_size: int,
vae_scale_factor_temporal: int,
vae_scale_factor_spatial: int,
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_channels_latents: int = 16,
):
rng = jax.random.key(self.config.seed)
num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
shape = (
batch_size,
num_channels_latents,
num_latent_frames,
int(height) // vae_scale_factor_spatial,
int(width) // vae_scale_factor_spatial,
)
latents = jax.random.normal(rng, shape=shape, dtype=self.config.weights_dtype)
return latents
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
num_videos_per_prompt: Optional[int] = 1,
max_sequence_length: int = 512,
latents: jax.Array = None,
prompt_embeds: jax.Array = None,
negative_prompt_embeds: jax.Array = None,
vae_only: bool = False,
):
if not vae_only:
if num_frames % self.vae_scale_factor_temporal != 1:
max_logging.log(
f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
)
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
prompt = [prompt]
batch_size = len(prompt)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
max_sequence_length=max_sequence_length,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
num_channel_latents = self.transformer.config.in_channels
if latents is None:
latents = self.prepare_latents(
batch_size=batch_size,
vae_scale_factor_temporal=self.vae_scale_factor_temporal,
vae_scale_factor_spatial=self.vae_scale_factor_spatial,
height=height,
width=width,
num_frames=num_frames,
num_channels_latents=num_channel_latents,
)
data_sharding = NamedSharding(self.mesh, P())
# Using global_batch_size_to_train_on so not to create more config variables
if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0:
data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))
latents = jax.device_put(latents, data_sharding)
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
scheduler_state = self.scheduler.set_timesteps(
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
)
graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)
p_run_inference = partial(
run_inference,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
scheduler=self.scheduler,
scheduler_state=scheduler_state,
num_transformer_layers=self.transformer.config.num_layers,
)
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
latents = p_run_inference(
graphdef=graphdef,
sharded_state=state,
rest_of_state=rest_of_state,
latents=latents,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1)
latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1)
latents = latents / latents_std + latents_mean
latents = latents.astype(self.config.weights_dtype)
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
video = self.vae.decode(latents, self.vae_cache)[0]
video = jnp.transpose(video, (0, 4, 1, 2, 3))
video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)
video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)
video = self.video_processor.postprocess_video(video, output_type="np")
return video
@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale"))
def transformer_forward_pass(
    graphdef,
    sharded_state,
    rest_of_state,
    latents,
    timestep,
    prompt_embeds,
    do_classifier_free_guidance,
    guidance_scale,
):
  """Run one jitted transformer step and optionally apply classifier-free guidance.

  Args:
    graphdef: Graph definition of the transformer.
    sharded_state: Parameter state merged with ``graphdef``.
    rest_of_state: Remaining state merged with ``graphdef``.
    latents: Current latents. With guidance, the first half of the batch is the
      conditional copy and the second half the unconditional copy.
    timestep: Per-sample timestep passed to the transformer.
    prompt_embeds: Text embeddings passed as ``encoder_hidden_states``.
    do_classifier_free_guidance: Static flag that enables the guidance combination.
    guidance_scale: Static guidance weight.

  Returns:
    ``(noise_pred, latents)``. With guidance, both are cut back to the first half of the batch.
  """
  model = nnx.merge(graphdef, sharded_state, rest_of_state)
  prediction = model(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds)
  if not do_classifier_free_guidance:
    return prediction, latents

  half = latents.shape[0] // 2
  cond_pred = prediction[:half]
  uncond_pred = prediction[half:]
  guided_pred = uncond_pred + guidance_scale * (cond_pred - uncond_pred)
  return guided_pred, latents[:half]
def run_inference(
    graphdef,
    sharded_state,
    rest_of_state,
    latents: jnp.array,
    prompt_embeds: jnp.array,
    negative_prompt_embeds: jnp.array,
    guidance_scale: float,
    num_inference_steps: int,
    scheduler: FlaxUniPCMultistepScheduler,
    num_transformer_layers: int,
    scheduler_state,
):
  """Run the denoising loop and return the final latents.

  Classifier-free guidance is used when ``guidance_scale > 1.0``. In that case
  the positive and negative embeddings are concatenated once, and the latents
  are duplicated at every step before the transformer call.

  Args:
    graphdef: Graph definition of the transformer.
    sharded_state: Parameter state of the transformer.
    rest_of_state: Remaining transformer state.
    latents: Initial latents.
    prompt_embeds: Positive text embeddings.
    negative_prompt_embeds: Negative text embeddings; only used with guidance.
    guidance_scale: Guidance weight.
    num_inference_steps: Number of loop iterations.
    scheduler: Scheduler whose ``step`` advances the latents.
    num_transformer_layers: Not read by this function.
    scheduler_state: Scheduler state providing ``timesteps``; replaced after every step.

  Returns:
    The latents after the last scheduler step.
  """
  use_guidance = guidance_scale > 1.0
  if use_guidance:
    prompt_embeds = jnp.concatenate((prompt_embeds, negative_prompt_embeds), axis=0)

  for step_index in range(num_inference_steps):
    current_t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step_index]
    if use_guidance:
      # Duplicate the batch so one forward pass serves both halves of the embeddings.
      latents = jnp.concatenate([latents, latents])
    batch_timesteps = jnp.broadcast_to(current_t, latents.shape[0])

    noise_pred, latents = transformer_forward_pass(
        graphdef,
        sharded_state,
        rest_of_state,
        latents,
        batch_timesteps,
        prompt_embeds,
        do_classifier_free_guidance=use_guidance,
        guidance_scale=guidance_scale,
    )

    step_output = scheduler.step(scheduler_state, noise_pred, current_t, latents)
    latents, scheduler_state = step_output.to_tuple()
  return latents