Skip to content

Commit a33e288

Browse files
committed
Remove vae_cache
1 parent 203f79b commit a33e288

1 file changed

Lines changed: 5 additions & 10 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
28 28
from ...max_utils import get_flash_block_sizes, get_precision, device_put_replicated
29 29
from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae
30 30
from ...models.wan.transformers.transformer_wan import WanModel
31-
from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache
31+
from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan
32 32
from maxdiffusion.video_processor import VideoProcessor
33 33
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
34 34
from transformers import AutoTokenizer, UMT5EncoderModel
@@ -195,7 +195,6 @@ def __init__(
195 195
low_noise_transformer: WanModel,
196 196
high_noise_transformer: WanModel,
197 197
vae: AutoencoderKLWan,
198-
vae_cache: AutoencoderKLWanCache,
199 198
scheduler: FlaxUniPCMultistepScheduler,
200 199
scheduler_state: UniPCMultistepSchedulerState,
201 200
devices_array: np.array,
@@ -207,7 +206,6 @@ def __init__(
207 206
self.low_noise_transformer = low_noise_transformer
208 207
self.high_noise_transformer = high_noise_transformer
209 208
self.vae = vae
210-
self.vae_cache = vae_cache
211 209
self.scheduler = scheduler
212 210
self.scheduler_state = scheduler_state
213 211
self.devices_array = devices_array
@@ -275,8 +273,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
275 273
state = nnx.from_flat_state(state)
276 274

277 275
wan_vae = nnx.merge(graphdef, state)
278-
vae_cache = AutoencoderKLWanCache(wan_vae)
279-
return wan_vae, vae_cache
276+
return wan_vae
280 277

281 278
@classmethod
282 279
def get_basic_config(cls, dtype, config: HyperParameters):
@@ -396,15 +393,14 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
396 393
scheduler, scheduler_state = cls.load_scheduler(config=config)
397 394

398 395
with mesh:
399-
wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
396+
wan_vae = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
400 397

401 398
return WanPipeline(
402 399
tokenizer=tokenizer,
403 400
text_encoder=text_encoder,
404 401
low_noise_transformer=low_noise_transformer,
405 402
high_noise_transformer=high_noise_transformer,
406 403
vae=wan_vae,
407-
vae_cache=vae_cache,
408 404
scheduler=scheduler,
409 405
scheduler_state=scheduler_state,
410 406
devices_array=devices_array,
@@ -435,15 +431,14 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
435 431
scheduler, scheduler_state = cls.load_scheduler(config=config)
436 432

437 433
with mesh:
438-
wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
434+
wan_vae = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
439 435

440 436
pipeline = WanPipeline(
441 437
tokenizer=tokenizer,
442 438
text_encoder=text_encoder,
443 439
low_noise_transformer=low_noise_transformer,
444 440
high_noise_transformer=high_noise_transformer,
445 441
vae=wan_vae,
446-
vae_cache=vae_cache,
447 442
scheduler=scheduler,
448 443
scheduler_state=scheduler_state,
449 444
devices_array=devices_array,
@@ -639,7 +634,7 @@ def __call__(
639 634
latents = latents.astype(jnp.float32)
640 635

641 636
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
642-
video = self.vae.decode(latents, self.vae_cache)[0]
637+
video = self.vae.decode(latents, return_dict=False)[0]
643 638

644 639
video = jnp.transpose(video, (0, 4, 1, 2, 3))
645 640
video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)

0 commit comments

Comments
 (0)