1616from typing import List , Union , Optional
1717from functools import partial
1818import numpy as np
19+ import math
1920import jax
2021import jax .numpy as jnp
2122from jax .sharding import Mesh , NamedSharding , PartitionSpec as P
@@ -201,6 +202,7 @@ def __init__(
201202 devices_array : np .array ,
202203 mesh : Mesh ,
203204 config : HyperParameters ,
205+ ** kwargs ,
204206 ):
205207 self .tokenizer = tokenizer
206208 self .text_encoder = text_encoder
@@ -213,6 +215,9 @@ def __init__(
213215 self .config = config
214216 self .model_name = config .model_name
215217
218+ self .vae_mesh = kwargs .get ("vae_mesh" , mesh )
219+ self .vae_logical_axis_rules = kwargs .get ("vae_logical_axis_rules" , config .logical_axis_rules )
220+
216221 self .vae_scale_factor_temporal = 2 ** sum (self .vae .temperal_downsample ) if getattr (self , "vae" , None ) else 4
217222 self .vae_scale_factor_spatial = 2 ** len (self .vae .temperal_downsample ) if getattr (self , "vae" , None ) else 8
218223 self .video_processor = VideoProcessor (vae_scale_factor = self .vae_scale_factor_spatial )
@@ -236,7 +241,7 @@ def load_tokenizer(cls, config: HyperParameters):
236241 return tokenizer
237242
238243 @classmethod
239- def load_vae (cls , devices_array : np .array , mesh : Mesh , rngs : nnx .Rngs , config : HyperParameters ):
244+ def load_vae (cls , devices_array : np .array , mesh : Mesh , rngs : nnx .Rngs , config : HyperParameters , vae_logical_axis_rules : tuple = None ):
240245
241246 def create_model (rngs : nnx .Rngs , config : HyperParameters ):
242247 wan_vae = AutoencoderKLWan .from_config (
@@ -256,7 +261,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
256261
257262 # 2. retrieve the state shardings, mapping logical names to mesh axis names.
258263 logical_state_spec = nnx .get_partition_spec (state )
259- logical_state_sharding = nn .logical_to_mesh_sharding (logical_state_spec , mesh , config .logical_axis_rules )
264+ logical_rules = vae_logical_axis_rules if vae_logical_axis_rules is not None else config .logical_axis_rules
265+ logical_state_sharding = nn .logical_to_mesh_sharding (logical_state_spec , mesh , logical_rules )
260266 logical_state_sharding = dict (nnx .to_flat_state (logical_state_sharding ))
261267 params = state .to_pure_dict ()
262268 state = dict (nnx .to_flat_state (state ))
@@ -470,7 +476,7 @@ def _denormalize_latents(self, latents: jax.Array) -> jax.Array:
470476
471477 def _decode_latents_to_video (self , latents : jax .Array ) -> np .ndarray :
472478 """Decodes latents to video frames and postprocesses."""
473- with self .mesh , nn_partitioning .axis_rules (self .config . logical_axis_rules ):
479+ with self .vae_mesh , nn_partitioning .axis_rules (self .vae_logical_axis_rules ):
474480 video = self .vae .decode (latents , self .vae_cache )[0 ]
475481
476482 video = jnp .transpose (video , (0 , 4 , 1 , 2 , 3 ))
@@ -482,15 +488,49 @@ def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray:
482488 def _create_common_components (cls , config , vae_only = False ):
483489 devices_array = max_utils .create_device_mesh (config )
484490 mesh = Mesh (devices_array , config .mesh_axes )
491+
492+ vae_spatial = getattr (config , "vae_spatial" , - 1 )
493+ total_devices = math .prod (devices_array .shape )
494+
495+ if vae_spatial <= 0 :
496+ dp_size = mesh .shape .get ("data" , 1 )
497+ if dp_size == - 1 or dp_size == 0 :
498+ dp_size = 1
499+ vae_spatial = (2 * total_devices ) // dp_size
500+
501+ assert total_devices % vae_spatial == 0 , f"total devices ({ total_devices } ) must be a multiple of vae_spatial ({ vae_spatial } )"
502+
503+ flat_devices = devices_array .flatten ()
504+ vae_devices_array = flat_devices .reshape (total_devices // vae_spatial , vae_spatial )
505+
506+ vae_mesh = Mesh (vae_devices_array , ("redundant" , "vae_spatial" ))
507+ vae_mesh .vae_spatial_axis_name = "vae_spatial"
508+ max_logging .log (f"Created VAE specific mesh with axes ('redundant', 'vae_spatial') to support spatial sharding of { vae_spatial } ." )
509+
510+ # logical axis rules for VAE encoding/decoding
511+ vae_logical_axis_rules = (
512+ ("activation_batch" , "redundant" ),
513+ ("activation_length" , "vae_spatial" ),
514+ ("activation_heads" , None ),
515+ ("activation_kv_length" , None ),
516+ ("embed" , None ),
517+ ("heads" , None ),
518+ ("norm" , None ),
519+ ("conv_batch" , "redundant" ),
520+ ("out_channels" , "vae_spatial" ),
521+ ("conv_out" , "vae_spatial" )
522+ )
523+
485524 rng = jax .random .key (config .seed )
486525 rngs = nnx .Rngs (rng )
487526
488- with mesh :
489- wan_vae , vae_cache = cls .load_vae (devices_array = devices_array , mesh = mesh , rngs = rngs , config = config )
527+ with vae_mesh :
528+ wan_vae , vae_cache = cls .load_vae (devices_array = devices_array , mesh = vae_mesh , rngs = rngs , config = config , vae_logical_axis_rules = vae_logical_axis_rules )
490529
491530 components = {
492531 "vae" : wan_vae , "vae_cache" : vae_cache ,
493- "devices_array" : devices_array , "rngs" : rngs , "mesh" : mesh ,
532+ "devices_array" : devices_array , "rngs" : rngs , "mesh" : mesh , "vae_mesh" : vae_mesh ,
533+ "vae_logical_axis_rules" : vae_logical_axis_rules ,
494534 "tokenizer" : None , "text_encoder" : None , "scheduler" : None , "scheduler_state" : None
495535 }
496536
0 commit comments