Skip to content

Commit c5fb919

Browse files
committed
Changing data -> None
1 parent 919bba3 commit c5fb919

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
156156
shard_width_axis = "context"
157157

158158
x_padded = jax.lax.with_sharding_constraint(
159-
x_padded, jax.sharding.PartitionSpec("data", None, shard_axis, shard_width_axis, None)
159+
x_padded, jax.sharding.PartitionSpec(None, None, shard_axis, shard_width_axis, None)
160160
)
161161

162162
out = self.conv(x_padded)

0 commit comments

Comments
 (0)