
Commit 705b813

spatial sharding added
1 parent 1e1058a · commit 705b813

1 file changed: 10 additions & 0 deletions

src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -21,6 +21,8 @@
 import jax.numpy as jnp
 from jax import tree_util
 from flax import nnx
+from jax.sharding import NamedSharding
+from jax.sharding import PartitionSpec as P
 from ...configuration_utils import ConfigMixin
 from ..modeling_flax_utils import FlaxModelMixin, get_activation
 from ... import common_types
@@ -144,6 +146,14 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
     else:
       x_padded = x

+    if self.mesh is not None:
+      # Shard height dimension (index 2) along 'context' axis
+      # Shape is (Batch, Time, Height, Width, Channels)
+      # We only shard if the dimension is divisible by the mesh size to avoid XLA errors
+      if x_padded.shape[2] % self.mesh.shape["context"] == 0:
+        sharding = NamedSharding(self.mesh, P(None, None, "context", None, None))
+        x_padded = jax.lax.with_sharding_constraint(x_padded, sharding)
+
     out = self.conv(x_padded)
     return out
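
For context, a minimal runnable sketch (not part of the commit) of what the new constraint does in isolation. The mesh construction, the shard_height wrapper, and the example array shape are illustrative assumptions; only the axis name 'context' and the PartitionSpec P(None, None, "context", None, None) come from the diff.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over all local devices whose single axis is named
# 'context', matching the axis name the commit shards over.
mesh = Mesh(np.array(jax.devices()), axis_names=("context",))

@jax.jit
def shard_height(x):
  # Constrain the height dimension (index 2) of a (B, T, H, W, C) tensor
  # to be split across the 'context' mesh axis, as the commit does before
  # running the convolution.
  if x.shape[2] % mesh.shape["context"] == 0:  # shapes are static under jit
    sharding = NamedSharding(mesh, P(None, None, "context", None, None))
    x = jax.lax.with_sharding_constraint(x, sharding)
  return x

x = jnp.zeros((1, 4, 8, 8, 16))  # (Batch, Time, Height, Width, Channels)
print(shard_height(x).sharding)  # H split across 'context' when divisible

The divisibility guard mirrors the commit's comment: when the height is not evenly divisible by the mesh size, the constraint is skipped rather than risking an XLA partitioning error, and the tensor keeps whatever sharding it already had.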
