|
19 | 19 | import jax |
20 | 20 | import jax.numpy as jnp |
21 | 21 | from flax import nnx |
| 22 | +from jax.sharding import PartitionSpec as P |
22 | 23 | from ...configuration_utils import ConfigMixin |
23 | 24 | from ..modeling_flax_utils import FlaxModelMixin, get_activation |
24 | 25 | from ... import common_types |
@@ -116,6 +117,7 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> |
116 | 117 | x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0) |
117 | 118 | else: |
118 | 119 | x_padded = x |
| 120 | + x_padded = jax.lax.with_sharding_constraint(x_padded, P(None, None, 'fsdp', None, None)) |
119 | 121 | out = self.conv(x_padded) |
120 | 122 | return out |
121 | 123 |
|
@@ -336,6 +338,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: |
336 | 338 | x = x.reshape(b, t * 2, h, w, c) |
337 | 339 | t = x.shape[1] |
338 | 340 | x = x.reshape(b * t, h, w, c) |
| 341 | + x = jax.lax.with_sharding_constraint(x, P(None, 'fsdp', None, None)) |
339 | 342 | x = self.resample(x) |
340 | 343 | h_new, w_new, c_new = x.shape[1:] |
341 | 344 | x = x.reshape(b, t, h_new, w_new, c_new) |
@@ -486,6 +489,8 @@ def __call__(self, x: jax.Array): |
486 | 489 | identity = x |
487 | 490 | batch_size, time, height, width, channels = x.shape |
488 | 491 |
|
| 492 | + x = jax.lax.with_sharding_constraint(x, P(None, None, 'fsdp', None, None)) |
| 493 | + |
489 | 494 | x = x.reshape(batch_size * time, height, width, channels) |
490 | 495 | x = self.norm(x) |
491 | 496 |
|
|
0 commit comments