Skip to content

Commit f18ae2f

Browse files
committed
Changed none -> data
1 parent 705b813 commit f18ae2f

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
151151
# Shape is (Batch, Time, Height, Width, Channels)
152152
# We only shard if the dimension is divisible by the mesh size to avoid XLA errors
153153
if x_padded.shape[2] % self.mesh.shape["context"] == 0:
154-
sharding = NamedSharding(self.mesh, P(None, None, "context", None, None))
154+
sharding = NamedSharding(self.mesh, P("data", None, "context", None, None))
155155
x_padded = jax.lax.with_sharding_constraint(x_padded, sharding)
156156

157157
out = self.conv(x_padded)

0 commit comments

Comments
 (0)