@@ -99,10 +99,10 @@ def __init__(
9999 self .mesh = mesh
100100
101101 # Weight sharding (Kernel is sharded along output channels)
102- num_fsdp_devices = mesh .shape ["fsdp " ]
102+ num_fsdp_devices = mesh .shape ["vae_spatial " ]
103103 kernel_sharding = (None , None , None , None , None )
104104 if out_channels % num_fsdp_devices == 0 :
105- kernel_sharding = (None , None , None , None , "fsdp " )
105+ kernel_sharding = (None , None , None , None , "vae_spatial " )
106106
107107 self .conv = nnx .Conv (
108108 in_features = in_channels ,
@@ -121,7 +121,7 @@ def __init__(
121121 def __call__ (self , x : jax .Array , cache_x : Optional [jax .Array ] = None , idx = - 1 ) -> jax .Array :
122122 # Sharding Width (index 3)
123123 # Spec: (Batch, Time, Height, Width, Channels)
124- spatial_sharding = NamedSharding (self .mesh , P (None , None , None , "fsdp " , None ))
124+ spatial_sharding = NamedSharding (self .mesh , P (None , None , None , "vae_spatial " , None ))
125125 x = jax .lax .with_sharding_constraint (x , spatial_sharding )
126126
127127 current_padding = list (self ._causal_padding )
@@ -1098,7 +1098,7 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
10981098 iter_ = 1 + (t - 1 ) // 4
10991099 enc_feat_map = feat_cache ._enc_feat_map
11001100
1101- spatial_sharding = NamedSharding (self .mesh , P (None , None , None , "fsdp " , None ))
1101+ spatial_sharding = NamedSharding (self .mesh , P (None , None , None , "vae_spatial " , None ))
11021102
11031103 # First iteration (i=0): size 1
11041104 chunk_0 = x [:, :1 , ...]
@@ -1180,7 +1180,7 @@ def _decode(
11801180
11811181 dec_feat_map = feat_cache ._feat_map
11821182 # NamedSharding for the Width axis (axis 3)
1183- spatial_sharding = NamedSharding (self .mesh , P (None , None , None , "fsdp " , None ))
1183+ spatial_sharding = NamedSharding (self .mesh , P (None , None , None , "vae_spatial " , None ))
11841184
11851185 # First chunk (i=0)
11861186 chunk_in_0 = jax .lax .with_sharding_constraint (x [:, 0 :1 , ...], spatial_sharding )
@@ -1264,4 +1264,4 @@ def decode(
12641264 decoded = self ._decode (z , feat_cache ).sample
12651265 if not return_dict :
12661266 return (decoded ,)
1267- return FlaxDecoderOutput (sample = decoded )
1267+ return FlaxDecoderOutput (sample = decoded )
0 commit comments