@@ -146,13 +146,19 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
         else:
             x_padded = x
 
-        if self.mesh is not None:
-            # Shard height dimension (index 2) along 'context' axis
-            # Shape is (Batch, Time, Height, Width, Channels)
-            # We only shard if the dimension is divisible by the mesh size to avoid XLA errors
-            if x_padded.shape[2] % self.mesh.shape["context"] == 0:
-                sharding = NamedSharding(self.mesh, P(None, None, "context", None, None))
-                x_padded = jax.lax.with_sharding_constraint(x_padded, sharding)
+        if self.mesh is not None and "context" in self.mesh.axis_names:
+            height = x_padded.shape[2]
+            width = x_padded.shape[3]
+            num_context_devices = self.mesh.shape["context"]
+
+            shard_axis = "context" if (height % num_context_devices == 0) else None
+            shard_width_axis = None
+            if shard_axis is None and width % num_context_devices == 0:
+                shard_width_axis = "context"
+
+            x_padded = jax.lax.with_sharding_constraint(
+                x_padded, jax.sharding.PartitionSpec("data", None, shard_axis, shard_width_axis, None)
+            )
 
         out = self.conv(x_padded)
         return out
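The hunk above replaces the height-only constraint with a height-then-width fallback: shard height (axis 2) along the "context" mesh axis when it divides evenly, otherwise try width (axis 3), otherwise leave both spatial axes replicated. Below is a minimal, self-contained sketch of that logic; the mesh construction and the `constrain_context_sharding` helper are illustrative stand-ins, not code from this patch, and they assume at least two devices with mesh axes named ("data", "context") as in the diff.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative mesh (not from the patch); on CPU, run with e.g.
# XLA_FLAGS=--xla_force_host_platform_device_count=2 to get two devices.
mesh = Mesh(np.array(jax.devices()[:2]).reshape(1, 2), axis_names=("data", "context"))

def constrain_context_sharding(x: jax.Array, mesh: Mesh) -> jax.Array:
    # x has shape (Batch, Time, Height, Width, Channels).
    # Prefer sharding Height along "context"; fall back to Width when
    # Height is not divisible by the number of context devices.
    height, width = x.shape[2], x.shape[3]
    n = mesh.shape["context"]
    shard_h = "context" if height % n == 0 else None
    shard_w = "context" if shard_h is None and width % n == 0 else None
    spec = PartitionSpec("data", None, shard_h, shard_w, None)
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))

x = jax.numpy.ones((1, 4, 6, 8, 3))  # Height 6 divides 2, so Height is sharded
y = jax.jit(lambda a: constrain_context_sharding(a, mesh))(x)
print(y.sharding)  # Height split along the "context" axis
```

Sharding only when the dimension divides evenly preserves the intent of the removed comment: uneven partitions would make XLA pad or reject the layout.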
@@ -769,7 +775,6 @@ def __init__(
             precision=precision,
         )
 
-    @nnx.jit(static_argnames="feat_idx")
     def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
         if feat_cache is not None:
             idx = feat_idx
@@ -918,7 +923,6 @@ def __init__(
             precision=precision,
         )
 
-    @nnx.jit(static_argnames="feat_idx")
     def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
         if feat_cache is not None:
             idx = feat_idx
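The two hunks above drop the per-block `@nnx.jit(static_argnames="feat_idx")` decorators: a static `feat_idx` forces one recompilation per index value, and every decorated block adds its own trace boundary. The later hunks jit once at the `_encode`/`_decode` level instead. A minimal sketch of that pattern, using hypothetical `Block`/`Stack` modules rather than the classes in this file:

```python
from flax import nnx
import jax
import jax.numpy as jnp

class Block(nnx.Module):
    def __init__(self, din: int, dout: int, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dout, rngs=rngs)

    # No @nnx.jit here: jitting each block with feat_idx static would
    # trigger one compilation per block and per index value.
    def __call__(self, x: jax.Array, feat_idx: int = 0) -> jax.Array:
        return self.linear(x)

class Stack(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.blocks = [Block(8, 8, rngs) for _ in range(3)]

    @nnx.jit  # single compilation boundary around the whole forward pass
    def __call__(self, x: jax.Array) -> jax.Array:
        for i, block in enumerate(self.blocks):
            x = block(x, feat_idx=i)  # indices become compile-time constants
        return x

y = Stack(nnx.Rngs(0))(jnp.ones((1, 8)))
```

Because the Python loop unrolls during the single outer trace, the per-block indices cost nothing at runtime and no retracing occurs across calls.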
@@ -1113,8 +1117,8 @@ def __init__(
             precision=precision,
         )
 
+    @nnx.jit
     def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
-        feat_cache.init_cache()
         if x.shape[-1] != 3:
             # reshape channel last for JAX
             x = jnp.transpose(x, (0, 2, 3, 4, 1))
@@ -1136,29 +1140,27 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
                 )
                 out = jnp.concatenate([out, out_], axis=1)
 
-        # Update back to the wrapper object if needed, but for result we use local vars
-        feat_cache._enc_feat_map = enc_feat_map
-
         enc = self.quant_conv(out)
         mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
         enc = jnp.concatenate([mu, logvar], axis=-1)
-        feat_cache.init_cache()
         return enc
 
     def encode(
         self, x: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
     ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
         """Encode video into latent distribution."""
+        feat_cache.init_cache()
         h = self._encode(x, feat_cache)
+        feat_cache.init_cache()
         posterior = WanDiagonalGaussianDistribution(h)
         if not return_dict:
             return (posterior,)
         return FlaxAutoencoderKLOutput(latent_dist=posterior)
 
+    @nnx.jit
     def _decode(
-        self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
-    ) -> Union[FlaxDecoderOutput, jax.Array]:
-        feat_cache.init_cache()
+        self, z: jax.Array, feat_cache: AutoencoderKLWanCache
+    ) -> jax.Array:
         iter_ = z.shape[1]
         x = self.post_quant_conv(z)
 
@@ -1188,14 +1190,8 @@ def _decode(
             fm4 = jnp.expand_dims(fm4, axis=axis)
             out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
 
-        feat_cache._feat_map = dec_feat_map
-
         out = jnp.clip(out, min=-1.0, max=1.0)
-        feat_cache.init_cache()
-        if not return_dict:
-            return (out,)
-
-        return FlaxDecoderOutput(sample=out)
+        return out
 
     def decode(
         self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
@@ -1204,7 +1200,9 @@ def decode(
             # reshape channel last for JAX
             z = jnp.transpose(z, (0, 2, 3, 4, 1))
         assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}), got {z.shape}"
-        decoded = self._decode(z, feat_cache).sample
+        feat_cache.init_cache()
+        decoded = self._decode(z, feat_cache)
+        feat_cache.init_cache()
         if not return_dict:
             return (decoded,)
         return FlaxDecoderOutput(sample=decoded)
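Taken together, the last hunks pull the `feat_cache.init_cache()` resets out of the now-jitted `_encode`/`_decode` and into the un-jitted `encode`/`decode` wrappers, so the resets run as plain Python side effects rather than traced operations, and `_decode` returns a bare array so the jitted body carries no `return_dict` branching. A sketch of the bracketing pattern, with hypothetical `ToyCache`/`ToyCodec` stand-ins for `AutoencoderKLWanCache` and the VAE (not code from this patch):

```python
from flax import nnx
import jax
import jax.numpy as jnp

class ToyCache(nnx.Module):
    """Hypothetical stand-in for AutoencoderKLWanCache."""
    def __init__(self):
        self.feat = nnx.Variable(jnp.zeros((4,)))

    def init_cache(self):
        self.feat.value = jnp.zeros((4,))

class ToyCodec(nnx.Module):
    @nnx.jit  # jitted body returns a plain array; no return_dict branching
    def _decode(self, z: jax.Array, cache: ToyCache) -> jax.Array:
        cache.feat.value = cache.feat.value + z  # cache mutation under jit
        return cache.feat.value

    def decode(self, z: jax.Array, cache: ToyCache) -> jax.Array:
        cache.init_cache()            # start every call from a clean cache
        out = self._decode(z, cache)
        cache.init_cache()            # drop cached activations afterwards
        return out

codec, cache = ToyCodec(), ToyCache()
print(codec.decode(jnp.ones((4,)), cache))
```

Resetting on both sides of the jitted call keeps the compiled function free of reset logic while guaranteeing no stale features leak between invocations.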