|
22 | 22 | import jax |
23 | 23 | import jax.numpy as jnp |
24 | 24 | from flax import nnx |
| 25 | +from flax import linen as nn |
25 | 26 | from flax.linen import partitioning as nn_partitioning |
26 | 27 | from jax.sharding import Mesh |
27 | 28 | from .. import pyconfig |
28 | 29 | from ..max_utils import ( |
29 | 30 | create_device_mesh, |
| 31 | + device_put_replicated, |
30 | 32 | ) |
31 | 33 | import numpy as np |
32 | 34 | import unittest |
@@ -556,10 +558,18 @@ def vae_encode(video, wan_vae, vae_cache, key): |
556 | 558 | # This replaces random params with the model. |
557 | 559 | params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") |
558 | 560 | params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) |
559 | | - # Transfer params to TPU device before merging to fix device mismatch |
560 | | - tpu_device = jax.devices("tpu")[0] if jax.devices("tpu") else jax.devices()[0] |
561 | | - params = jax.device_put(params, tpu_device) |
562 | | - wan_vae = nnx.merge(graphdef, params) |
| 561 | + |
| 562 | + logical_state_spec = nnx.get_partition_spec(state) |
| 563 | + logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) |
| 564 | + logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) |
| 565 | + |
| 566 | + state_flat = dict(nnx.to_flat_state(state)) |
| 567 | + for path, val in flax.traverse_util.flatten_dict(params).items(): |
| 568 | + sharding = logical_state_sharding[path].get_value() |
| 569 | + state_flat[path][...] = device_put_replicated(val, sharding) |
| 570 | + state = nnx.from_flat_state(state_flat) |
| 571 | + |
| 572 | + wan_vae = nnx.merge(graphdef, state) |
563 | 573 |
|
564 | 574 | p_vae_encode = functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key) |
565 | 575 | original_video_shape = original_video.shape |
|
0 commit comments