2828from ...max_utils import get_flash_block_sizes , get_precision , device_put_replicated
2929from ...models .wan .wan_utils import load_wan_transformer , load_wan_vae
3030from ...models .wan .transformers .transformer_wan import WanModel
31- from ...models .wan .autoencoder_kl_wan import AutoencoderKLWan , AutoencoderKLWanCache
31+ from ...models .wan .autoencoder_kl_wan import AutoencoderKLWan
3232from maxdiffusion .video_processor import VideoProcessor
3333from ...schedulers .scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler , UniPCMultistepSchedulerState
3434from transformers import AutoTokenizer , UMT5EncoderModel
@@ -195,7 +195,6 @@ def __init__(
195195 low_noise_transformer : WanModel ,
196196 high_noise_transformer : WanModel ,
197197 vae : AutoencoderKLWan ,
198- vae_cache : AutoencoderKLWanCache ,
199198 scheduler : FlaxUniPCMultistepScheduler ,
200199 scheduler_state : UniPCMultistepSchedulerState ,
201200 devices_array : np .array ,
@@ -207,7 +206,6 @@ def __init__(
207206 self .low_noise_transformer = low_noise_transformer
208207 self .high_noise_transformer = high_noise_transformer
209208 self .vae = vae
210- self .vae_cache = vae_cache
211209 self .scheduler = scheduler
212210 self .scheduler_state = scheduler_state
213211 self .devices_array = devices_array
@@ -275,8 +273,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
275273 state = nnx .from_flat_state (state )
276274
277275 wan_vae = nnx .merge (graphdef , state )
278- vae_cache = AutoencoderKLWanCache (wan_vae )
279- return wan_vae , vae_cache
276+ return wan_vae
280277
281278 @classmethod
282279 def get_basic_config (cls , dtype , config : HyperParameters ):
@@ -396,15 +393,14 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
396393 scheduler , scheduler_state = cls .load_scheduler (config = config )
397394
398395 with mesh :
399- wan_vae , vae_cache = cls .load_vae (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config )
396+ wan_vae = cls .load_vae (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config )
400397
401398 return WanPipeline (
402399 tokenizer = tokenizer ,
403400 text_encoder = text_encoder ,
404401 low_noise_transformer = low_noise_transformer ,
405402 high_noise_transformer = high_noise_transformer ,
406403 vae = wan_vae ,
407- vae_cache = vae_cache ,
408404 scheduler = scheduler ,
409405 scheduler_state = scheduler_state ,
410406 devices_array = devices_array ,
@@ -435,15 +431,14 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
435431 scheduler , scheduler_state = cls .load_scheduler (config = config )
436432
437433 with mesh :
438- wan_vae , vae_cache = cls .load_vae (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config )
434+ wan_vae = cls .load_vae (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config )
439435
440436 pipeline = WanPipeline (
441437 tokenizer = tokenizer ,
442438 text_encoder = text_encoder ,
443439 low_noise_transformer = low_noise_transformer ,
444440 high_noise_transformer = high_noise_transformer ,
445441 vae = wan_vae ,
446- vae_cache = vae_cache ,
447442 scheduler = scheduler ,
448443 scheduler_state = scheduler_state ,
449444 devices_array = devices_array ,
@@ -639,7 +634,7 @@ def __call__(
639634 latents = latents .astype (jnp .float32 )
640635
641636 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
642- video = self .vae .decode (latents , self . vae_cache )[0 ]
637+ video = self .vae .decode (latents , return_dict = False )[0 ]
643638
644639 video = jnp .transpose (video , (0 , 4 , 1 , 2 , 3 ))
645640 video = jax .experimental .multihost_utils .process_allgather (video , tiled = True )
0 commit comments