Skip to content

Commit ebeaa10

Browse files
committed
move params to tpu device
1 parent bda990b commit ebeaa10

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -556,17 +556,17 @@ def vae_encode(video, wan_vae, vae_cache, key):
556556
# This replaces random params with the model.
557557
params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
558558
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
559+
# Transfer params to TPU device before merging to fix device mismatch
560+
tpu_device = jax.devices("tpu")[0] if jax.devices("tpu") else jax.devices()[0]
561+
params = jax.device_put(params, tpu_device)
559562
wan_vae = nnx.merge(graphdef, params)
560563

561564
p_vae_encode = functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key)
562565
original_video_shape = original_video.shape
566+
latent = p_vae_encode(original_video)
563567

564-
# Execute VAE operations within mesh context to match sharding constraints
565-
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
566-
latent = p_vae_encode(original_video)
567-
568-
jitted_decode = functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False)
569-
video = jitted_decode(latent)[0]
568+
jitted_decode = functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False)
569+
video = jitted_decode(latent)[0]
570570
video = jnp.transpose(video, (0, 4, 1, 2, 3))
571571
assert video.shape == original_video_shape
572572

0 commit comments

Comments
 (0)