@@ -556,17 +556,17 @@ def vae_encode(video, wan_vae, vae_cache, key):
 # This replaces random params with the model.
 params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
 params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+# Transfer params to the TPU device (or the default device when no TPU is attached) before merging, to fix the device mismatch
+tpu_device = next((d for d in jax.devices() if d.platform == "tpu"), jax.devices()[0])
+params = jax.device_put(params, tpu_device)
 wan_vae = nnx.merge(graphdef, params)

 p_vae_encode = functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key)
 original_video_shape = original_video.shape
+latent = p_vae_encode(original_video)

-# Execute VAE operations within mesh context to match sharding constraints
-with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
-    latent = p_vae_encode(original_video)
-
-jitted_decode = functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False)
-video = jitted_decode(latent)[0]
+jitted_decode = functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False)
+video = jitted_decode(latent)[0]
 video = jnp.transpose(video, (0, 4, 1, 2, 3))
 assert video.shape == original_video_shape

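For context, a minimal sketch of the device-placement pattern the added lines rely on: `jax.device_put` accepts an entire pytree, so the loaded parameters can be moved onto a single target device before `nnx.merge`, which is what resolves the mismatch between host-loaded params and the accelerator-resident module. The `SmallMLP` module below is hypothetical, a stand-in for the real Wan VAE.

```python
import jax
import jax.numpy as jnp
from flax import nnx

# Hypothetical stand-in for the Wan VAE; only the split/merge mechanics matter.
class SmallMLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 4, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

graphdef, params = nnx.split(SmallMLP(nnx.Rngs(0)))
# Same dtype cast as in the code above.
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)

# Prefer a TPU if one is attached; otherwise fall back to the default device.
# Filtering jax.devices() avoids the RuntimeError that jax.devices("tpu")
# raises on hosts without a TPU backend.
target = next((d for d in jax.devices() if d.platform == "tpu"), jax.devices()[0])

# device_put maps over the whole pytree, so every leaf lands on `target`
# and nnx.merge yields a module whose parameters agree on placement.
params = jax.device_put(params, target)
model = nnx.merge(graphdef, params)
print(model(jnp.ones((1, 4), dtype=jnp.bfloat16)))
```

Selecting the device via `next()` with a default keeps the lookup total: `jax.devices("tpu")` raises a RuntimeError when no TPU backend is available, whereas filtering the default device list degrades gracefully to GPU or CPU.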