Skip to content

Commit c236d56

Browse files
committed
Renaming VAE sharding axis to vae_spatial
1 parent e7cd3c4 commit c236d56

1 file changed

Lines changed: 46 additions & 6 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import List, Union, Optional
1717
from functools import partial
1818
import numpy as np
19+
import math
1920
import jax
2021
import jax.numpy as jnp
2122
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
@@ -201,6 +202,7 @@ def __init__(
201202
devices_array: np.array,
202203
mesh: Mesh,
203204
config: HyperParameters,
205+
**kwargs,
204206
):
205207
self.tokenizer = tokenizer
206208
self.text_encoder = text_encoder
@@ -213,6 +215,9 @@ def __init__(
213215
self.config = config
214216
self.model_name = config.model_name
215217

218+
self.vae_mesh = kwargs.get("vae_mesh", mesh)
219+
self.vae_logical_axis_rules = kwargs.get("vae_logical_axis_rules", config.logical_axis_rules)
220+
216221
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
217222
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
218223
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -236,7 +241,7 @@ def load_tokenizer(cls, config: HyperParameters):
236241
return tokenizer
237242

238243
@classmethod
239-
def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
244+
def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, vae_logical_axis_rules: tuple = None):
240245

241246
def create_model(rngs: nnx.Rngs, config: HyperParameters):
242247
wan_vae = AutoencoderKLWan.from_config(
@@ -256,7 +261,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
256261

257262
# 2. retrieve the state shardings, mapping logical names to mesh axis names.
258263
logical_state_spec = nnx.get_partition_spec(state)
259-
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
264+
logical_rules = vae_logical_axis_rules if vae_logical_axis_rules is not None else config.logical_axis_rules
265+
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, logical_rules)
260266
logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
261267
params = state.to_pure_dict()
262268
state = dict(nnx.to_flat_state(state))
@@ -470,7 +476,7 @@ def _denormalize_latents(self, latents: jax.Array) -> jax.Array:
470476

471477
def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray:
472478
"""Decodes latents to video frames and postprocesses."""
473-
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
479+
with self.vae_mesh, nn_partitioning.axis_rules(self.vae_logical_axis_rules):
474480
video = self.vae.decode(latents, self.vae_cache)[0]
475481

476482
video = jnp.transpose(video, (0, 4, 1, 2, 3))
@@ -482,15 +488,49 @@ def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray:
482488
def _create_common_components(cls, config, vae_only=False):
483489
devices_array = max_utils.create_device_mesh(config)
484490
mesh = Mesh(devices_array, config.mesh_axes)
491+
492+
vae_spatial = getattr(config, "vae_spatial", -1)
493+
total_devices = math.prod(devices_array.shape)
494+
495+
if vae_spatial <= 0:
496+
dp_size = mesh.shape.get("data", 1)
497+
if dp_size == -1 or dp_size == 0:
498+
dp_size = 1
499+
vae_spatial = (2 * total_devices) // dp_size
500+
501+
assert total_devices % vae_spatial == 0, f"total devices ({total_devices}) must be a multiple of vae_spatial ({vae_spatial})"
502+
503+
flat_devices = devices_array.flatten()
504+
vae_devices_array = flat_devices.reshape(total_devices // vae_spatial, vae_spatial)
505+
506+
vae_mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial"))
507+
vae_mesh.vae_spatial_axis_name = "vae_spatial"
508+
max_logging.log(f"Created VAE specific mesh with axes ('redundant', 'vae_spatial') to support spatial sharding of {vae_spatial}.")
509+
510+
# logical axis rules for VAE encoding/decoding
511+
vae_logical_axis_rules = (
512+
("activation_batch", "redundant"),
513+
("activation_length", "vae_spatial"),
514+
("activation_heads", None),
515+
("activation_kv_length", None),
516+
("embed", None),
517+
("heads", None),
518+
("norm", None),
519+
("conv_batch", "redundant"),
520+
("out_channels", "vae_spatial"),
521+
("conv_out", "vae_spatial")
522+
)
523+
485524
rng = jax.random.key(config.seed)
486525
rngs = nnx.Rngs(rng)
487526

488-
with mesh:
489-
wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
527+
with vae_mesh:
528+
wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=vae_mesh, rngs=rngs, config=config, vae_logical_axis_rules=vae_logical_axis_rules)
490529

491530
components = {
492531
"vae": wan_vae, "vae_cache": vae_cache,
493-
"devices_array": devices_array, "rngs": rngs, "mesh": mesh,
532+
"devices_array": devices_array, "rngs": rngs, "mesh": mesh, "vae_mesh": vae_mesh,
533+
"vae_logical_axis_rules": vae_logical_axis_rules,
494534
"tokenizer": None, "text_encoder": None, "scheduler": None, "scheduler_state": None
495535
}
496536

0 commit comments

Comments
 (0)