Skip to content

Commit 8a49410

Browse files
committed
Remove vae_cache
1 parent a33e288 commit 8a49410

1 file changed

Lines changed: 5 additions & 10 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
2828
from ...max_utils import get_flash_block_sizes, get_precision, device_put_replicated
2929
from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae
3030
from ...models.wan.transformers.transformer_wan import WanModel
31-
from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache
31+
from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan
3232
from maxdiffusion.video_processor import VideoProcessor
3333
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
3434
from transformers import AutoTokenizer, UMT5EncoderModel
@@ -196,7 +196,6 @@ def __init__(
196196
text_encoder: UMT5EncoderModel,
197197
transformer: WanModel,
198198
vae: AutoencoderKLWan,
199-
vae_cache: AutoencoderKLWanCache,
200199
scheduler: FlaxUniPCMultistepScheduler,
201200
scheduler_state: UniPCMultistepSchedulerState,
202201
devices_array: np.array,
@@ -207,7 +206,6 @@ def __init__(
207206
self.text_encoder = text_encoder
208207
self.transformer = transformer
209208
self.vae = vae
210-
self.vae_cache = vae_cache
211209
self.scheduler = scheduler
212210
self.scheduler_state = scheduler_state
213211
self.devices_array = devices_array
@@ -275,8 +273,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
275273
state = nnx.from_flat_state(state)
276274

277275
wan_vae = nnx.merge(graphdef, state)
278-
vae_cache = AutoencoderKLWanCache(wan_vae)
279-
return wan_vae, vae_cache
276+
return wan_vae
280277

281278
@classmethod
282279
def get_basic_config(cls, dtype, config: HyperParameters):
@@ -396,14 +393,13 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
396393
scheduler, scheduler_state = cls.load_scheduler(config=config)
397394

398395
with mesh:
399-
wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
396+
wan_vae = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
400397

401398
return WanPipeline(
402399
tokenizer=tokenizer,
403400
text_encoder=text_encoder,
404401
transformer=transformer,
405402
vae=wan_vae,
406-
vae_cache=vae_cache,
407403
scheduler=scheduler,
408404
scheduler_state=scheduler_state,
409405
devices_array=devices_array,
@@ -433,14 +429,13 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
433429
scheduler, scheduler_state = cls.load_scheduler(config=config)
434430

435431
with mesh:
436-
wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
432+
wan_vae = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
437433

438434
pipeline = WanPipeline(
439435
tokenizer=tokenizer,
440436
text_encoder=text_encoder,
441437
transformer=transformer,
442438
vae=wan_vae,
443-
vae_cache=vae_cache,
444439
scheduler=scheduler,
445440
scheduler_state=scheduler_state,
446441
devices_array=devices_array,
@@ -629,7 +624,7 @@ def __call__(
629624
latents = latents.astype(jnp.float32)
630625

631626
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
632-
video = self.vae.decode(latents, self.vae_cache)[0]
627+
video = self.vae.decode(latents, return_dict=False)[0]
633628

634629
video = jnp.transpose(video, (0, 4, 1, 2, 3))
635630
video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)

0 commit comments

Comments
 (0)