2828from ...max_utils import get_flash_block_sizes , get_precision , device_put_replicated
2929from ...models .wan .wan_utils import load_wan_transformer , load_wan_vae
3030from ...models .wan .transformers .transformer_wan import WanModel
31- from ...models .wan .autoencoder_kl_wan import AutoencoderKLWan , AutoencoderKLWanCache
31+ from ...models .wan .autoencoder_kl_wan import AutoencoderKLWan
3232from maxdiffusion .video_processor import VideoProcessor
3333from ...schedulers .scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler , UniPCMultistepSchedulerState
3434from transformers import AutoTokenizer , UMT5EncoderModel
@@ -196,7 +196,6 @@ def __init__(
196196 text_encoder : UMT5EncoderModel ,
197197 transformer : WanModel ,
198198 vae : AutoencoderKLWan ,
199- vae_cache : AutoencoderKLWanCache ,
200199 scheduler : FlaxUniPCMultistepScheduler ,
201200 scheduler_state : UniPCMultistepSchedulerState ,
202201 devices_array : np .array ,
@@ -207,7 +206,6 @@ def __init__(
207206 self .text_encoder = text_encoder
208207 self .transformer = transformer
209208 self .vae = vae
210- self .vae_cache = vae_cache
211209 self .scheduler = scheduler
212210 self .scheduler_state = scheduler_state
213211 self .devices_array = devices_array
@@ -275,8 +273,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
275273 state = nnx .from_flat_state (state )
276274
277275 wan_vae = nnx .merge (graphdef , state )
278- vae_cache = AutoencoderKLWanCache (wan_vae )
279- return wan_vae , vae_cache
276+ return wan_vae
280277
281278 @classmethod
282279 def get_basic_config (cls , dtype , config : HyperParameters ):
@@ -396,14 +393,13 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
396393 scheduler , scheduler_state = cls .load_scheduler (config = config )
397394
398395 with mesh :
399- wan_vae , vae_cache = cls .load_vae (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config )
396+ wan_vae = cls .load_vae (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config )
400397
401398 return WanPipeline (
402399 tokenizer = tokenizer ,
403400 text_encoder = text_encoder ,
404401 transformer = transformer ,
405402 vae = wan_vae ,
406- vae_cache = vae_cache ,
407403 scheduler = scheduler ,
408404 scheduler_state = scheduler_state ,
409405 devices_array = devices_array ,
@@ -433,14 +429,13 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
433429 scheduler , scheduler_state = cls .load_scheduler (config = config )
434430
435431 with mesh :
436- wan_vae , vae_cache = cls .load_vae (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config )
432+ wan_vae = cls .load_vae (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config )
437433
438434 pipeline = WanPipeline (
439435 tokenizer = tokenizer ,
440436 text_encoder = text_encoder ,
441437 transformer = transformer ,
442438 vae = wan_vae ,
443- vae_cache = vae_cache ,
444439 scheduler = scheduler ,
445440 scheduler_state = scheduler_state ,
446441 devices_array = devices_array ,
@@ -629,7 +624,7 @@ def __call__(
629624 latents = latents .astype (jnp .float32 )
630625
631626 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
632- video = self .vae .decode (latents , self . vae_cache )[0 ]
627+ video = self .vae .decode (latents , return_dict = False )[0 ]
633628
634629 video = jnp .transpose (video , (0 , 4 , 1 , 2 , 3 ))
635630 video = jax .experimental .multihost_utils .process_allgather (video , tiled = True )
0 commit comments