Skip to content

Commit 5bbc5b9

Browse files
committed
fixing vae tests
1 parent ebeaa10 commit 5bbc5b9

1 file changed

Lines changed: 14 additions & 4 deletions

File tree

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@
2222
import jax
2323
import jax.numpy as jnp
2424
from flax import nnx
25+
from flax import linen as nn
2526
from flax.linen import partitioning as nn_partitioning
2627
from jax.sharding import Mesh
2728
from .. import pyconfig
2829
from ..max_utils import (
2930
create_device_mesh,
31+
device_put_replicated,
3032
)
3133
import numpy as np
3234
import unittest
@@ -556,10 +558,18 @@ def vae_encode(video, wan_vae, vae_cache, key):
556558
# Replace the randomly initialized params with the ones loaded from the pretrained model.
557559
params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
558560
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
559-
# Transfer params to TPU device before merging to fix device mismatch
560-
tpu_device = jax.devices("tpu")[0] if jax.devices("tpu") else jax.devices()[0]
561-
params = jax.device_put(params, tpu_device)
562-
wan_vae = nnx.merge(graphdef, params)
561+
562+
logical_state_spec = nnx.get_partition_spec(state)
563+
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
564+
logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
565+
566+
state_flat = dict(nnx.to_flat_state(state))
567+
for path, val in flax.traverse_util.flatten_dict(params).items():
568+
sharding = logical_state_sharding[path].get_value()
569+
state_flat[path][...] = device_put_replicated(val, sharding)
570+
state = nnx.from_flat_state(state_flat)
571+
572+
wan_vae = nnx.merge(graphdef, state)
563573

564574
p_vae_encode = functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key)
565575
original_video_shape = original_video.shape

0 commit comments

Comments
 (0)