Commit e3e05a5

committed
Add spatial sharding
1 parent dc4a830 commit e3e05a5

1 file changed

Lines changed: 18 additions & 7 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

@@ -20,6 +20,7 @@
 import jax
 import jax.numpy as jnp
 from flax import nnx
+from jax.sharding import PartitionSpec
 from ...configuration_utils import ConfigMixin
 from ..modeling_flax_utils import FlaxModelMixin, get_activation
 from ... import common_types
@@ -57,6 +58,7 @@ def __init__(
     self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size")
     self.stride = _canonicalize_tuple(stride, 3, "stride")
     padding_tuple = _canonicalize_tuple(padding, 3, "padding")
+    self.mesh = mesh

     self._causal_padding = (
       (0, 0),
@@ -90,9 +92,22 @@ def __init__(
     )

   def initialize_cache(self, batch_size, height, width, dtype):
-    return jnp.zeros((batch_size, CACHE_T, height, width, self.conv.in_features), dtype=dtype)
+    # Create zeros
+    cache = jnp.zeros((batch_size, CACHE_T, height, width, self.conv.in_features), dtype=dtype)
+
+    # OPTIMIZATION: Spatial Partitioning on Initialization
+    # If we don't shard here, JAX allocates the full 2.64GB per chip, causing OOM.
+    if self.mesh is not None:
+      # Shard along Height (axis 2). Axis spec: (Batch, Time, Height, Width, Channels)
+      # The "fsdp" axis usually corresponds to the data parallelism or spatial split in this context.
+      cache = jax.lax.with_sharding_constraint(cache, PartitionSpec(None, None, "fsdp", None, None))
+    return cache

   def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> Tuple[jax.Array, jax.Array]:
+    # OPTIMIZATION: Spatial Partitioning during execution
+    if self.mesh is not None:
+      x = jax.lax.with_sharding_constraint(x, PartitionSpec(None, None, "fsdp", None, None))
+
     current_padding = list(self._causal_padding)

     if cache_x is not None:
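
A note on the constraint added above: the following is a minimal, self-contained sketch of the same technique, not code from this commit. The one-axis mesh named "fsdp", the illustrative cache shape (1, 2, 480, 832, 96), and the explicit NamedSharding are assumptions for illustration; height should be divisible by the device count.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all local devices; the axis name "fsdp" matches
# the PartitionSpec used in the commit.
mesh = Mesh(jax.devices(), axis_names=("fsdp",))
spec = PartitionSpec(None, None, "fsdp", None, None)  # split axis 2 = height

@jax.jit
def make_cache():
  # Illustrative shape: (batch, time, height, width, channels).
  cache = jnp.zeros((1, 2, 480, 832, 96), dtype=jnp.bfloat16)
  # Constrain the layout so each chip materializes only height / n_devices
  # rows of the cache, rather than every chip holding the full array.
  return jax.lax.with_sharding_constraint(cache, NamedSharding(mesh, spec))

print(make_cache().sharding)  # shows the height axis split across "fsdp"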
@@ -174,7 +189,6 @@ def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
         nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
       )
     elif mode == "upsample3d":
-      # 3D mode ALSO needs Sequential for the spatial part to match checkpoints
       self.resample = nnx.Sequential(
         WanUpsample(scale_factor=(2.0, 2.0), method="nearest"),
         nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
@@ -219,8 +233,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra

       b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
-      x = self.upsample(x)
-      x = self.conv(x)
+      x = self.resample(x)  # Sequential(WanUpsample, Conv)
       h_new, w_new, c_new = x.shape[1:]
       x = x.reshape(b, t, h_new, w_new, c_new)

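Aside: the reshape sandwich around self.resample is the standard fold-time-into-batch pattern for running a 2D spatial op over a 5D video tensor. A sketch of the pattern, using jax.image.resize as a stand-in for WanUpsample (an assumption; the real module pairs the upsample with a Conv):

import jax
import jax.numpy as jnp

def upsample_video(x: jax.Array) -> jax.Array:
  b, t, h, w, c = x.shape
  x = x.reshape(b * t, h, w, c)  # fold: each frame becomes a batch entry
  x = jax.image.resize(x, (b * t, h * 2, w * 2, c), method="nearest")
  h_new, w_new, c_new = x.shape[1:]
  return x.reshape(b, t, h_new, w_new, c_new)  # unfold back into a video

print(upsample_video(jnp.ones((1, 4, 8, 8, 16))).shape)  # (1, 4, 16, 16, 16)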
@@ -234,7 +247,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
     elif self.mode == "downsample3d":
       b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
-      x, _ = self.resample(x, None)  # Fixed: use self.resample not self.downsample_conv
+      x, _ = self.resample(x, None)  # ZeroPaddedConv2D
       h_new, w_new, c_new = x.shape[1:]
       x = x.reshape(b, t, h_new, w_new, c_new)

@@ -532,8 +545,6 @@ def __init__(self, rngs: nnx.Rngs, base_dim: int = 96, z_dim: int = 16, dim_mult
     self.z_dim = z_dim
     self.temperal_downsample = temperal_downsample
     self.temporal_upsample = temperal_downsample[::-1]
-
-    # MISSING attributes added back
     self.latents_mean = latents_mean
     self.latents_std = latents_std
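For context on latents_mean and latents_std: in comparable VAE implementations these statistics normalize latents before the diffusion process and denormalize them before decoding. A hedged sketch of that usage, since the call sites lie outside this commit and the exact broadcasting is an assumption based on the file's (batch, time, height, width, channels) layout:

import jax.numpy as jnp

def normalize_latents(z, latents_mean, latents_std):
  # Map encoder output into the unit-variance space the diffusion model expects.
  mean = jnp.asarray(latents_mean).reshape(1, 1, 1, 1, -1)  # broadcast over (b, t, h, w)
  std = jnp.asarray(latents_std).reshape(1, 1, 1, 1, -1)
  return (z - mean) / std

def denormalize_latents(z, latents_mean, latents_std):
  # Invert the normalization before handing latents back to the VAE decoder.
  mean = jnp.asarray(latents_mean).reshape(1, 1, 1, 1, -1)
  std = jnp.asarray(latents_std).reshape(1, 1, 1, 1, -1)
  return z * std + mean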