Skip to content

Commit 46d392f

Browse files
committed
changes for spatial sharding
1 parent 7f50b2b commit 46d392f

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import jax
2020
import jax.numpy as jnp
2121
from flax import nnx
22+
from jax.sharding import PartitionSpec as P
2223
from ...configuration_utils import ConfigMixin
2324
from ..modeling_flax_utils import FlaxModelMixin, get_activation
2425
from ... import common_types
@@ -116,6 +117,7 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
116117
x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
117118
else:
118119
x_padded = x
120+
x_padded = jax.lax.with_sharding_constraint(x_padded, P(None, None, 'fsdp', None, None))
119121
out = self.conv(x_padded)
120122
return out
121123

@@ -336,6 +338,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
336338
x = x.reshape(b, t * 2, h, w, c)
337339
t = x.shape[1]
338340
x = x.reshape(b * t, h, w, c)
341+
x = jax.lax.with_sharding_constraint(x, P(None, 'fsdp', None, None))
339342
x = self.resample(x)
340343
h_new, w_new, c_new = x.shape[1:]
341344
x = x.reshape(b, t, h_new, w_new, c_new)
@@ -486,6 +489,8 @@ def __call__(self, x: jax.Array):
486489
identity = x
487490
batch_size, time, height, width, channels = x.shape
488491

492+
x = jax.lax.with_sharding_constraint(x, P(None, None, 'fsdp', None, None))
493+
489494
x = x.reshape(batch_size * time, height, width, channels)
490495
x = self.norm(x)
491496

0 commit comments

Comments
 (0)