Skip to content

Commit 05e1b28

Browse files
committed
Only batch sharding added
1 parent e8bdd82 commit 05e1b28

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from ..modeling_flax_utils import FlaxModelMixin, get_activation
2525
from ... import common_types
2626
from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput)
27+
from jax.sharding import PartitionSpec
28+
from jax.lax import with_sharding_constraint
2729

2830
BlockSizes = common_types.BlockSizes
2931

@@ -117,6 +119,12 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
117119
x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
118120
else:
119121
x_padded = x
122+
123+
if self.mesh is not None:
124+
# (B, D, H, W, C)
125+
if x_padded.shape[0] % self.mesh.shape['data'] == 0:
126+
x_padded = with_sharding_constraint(x_padded, PartitionSpec('data', None, None, None, None))
127+
120128
out = self.conv(x_padded)
121129
return out
122130

0 commit comments

Comments (0)