Commit 0400c2c

spatial sharding

1 parent 6e3b58b

1 file changed: 22 additions & 0 deletions

src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -119,6 +119,28 @@ def __init__(
     )
 
   def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
+    # OPTIMIZATION: Spatial Partitioning during execution
+    if self.mesh is not None and "context" in self.mesh.axis_names:
+      height = x.shape[2]
+      width = x.shape[3]
+      num_context_devices = self.mesh.shape["context"]
+
+      shard_axis = "context" if (height % num_context_devices == 0) else None
+      shard_width_axis = None
+      if shard_axis is None and width % num_context_devices == 0:
+        shard_width_axis = "context"
+
+      x = jax.lax.with_sharding_constraint(
+          x, jax.sharding.PartitionSpec("data", None, shard_axis, shard_width_axis, None)
+      )
+
+      # Debug logging
+      if shard_axis or shard_width_axis:
+        jax.debug.print(
+            "Spatial sharding applied: height_axis={}, width_axis={} for shape {}",
+            shard_axis, shard_width_axis, x.shape
+        )
+
     current_padding = list(self._causal_padding)  # Mutable copy
     padding_needed = self._depth_padding_before
 
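
The sketch below is not part of the commit; it is a minimal, self-contained illustration of what the added constraint does: split the height dimension of an activation across a "context" mesh axis when it divides evenly, otherwise fall back to the width dimension. The (data, context) mesh layout, the 5-D (batch, frames, height, width, channels) activation shape, and the `constrain_spatial` helper name are assumptions for illustration; only the axis names and the partitioning logic come from the diff. The commit passes a bare `PartitionSpec` and relies on the surrounding mesh context, whereas the sketch uses an explicit `NamedSharding` to stay runnable on its own.

```python
# Standalone sketch of the spatial-sharding constraint (not the committed code).
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical mesh: one "data" group, all devices along the "context" axis.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("data", "context"))
num_context_devices = mesh.shape["context"]


def constrain_spatial(x):
  """Mirror the committed logic: shard height over 'context', else width."""
  height, width = x.shape[2], x.shape[3]
  shard_axis = "context" if height % num_context_devices == 0 else None
  shard_width_axis = None
  if shard_axis is None and width % num_context_devices == 0:
    shard_width_axis = "context"
  spec = PartitionSpec("data", None, shard_axis, shard_width_axis, None)
  # Explicit NamedSharding keeps the example independent of any mesh context.
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))


# Assumed (batch, frames, height, width, channels) activation, with height
# chosen so it divides evenly across the context axis.
x = jnp.zeros((1, 4, 8 * num_context_devices, 8, 16))
y = jax.jit(constrain_spatial)(x)
print(y.sharding)  # height dimension split across "context", others replicated
```

When height does not divide evenly (for example, odd resolutions after downsampling), the fallback keeps the constraint valid by sharding width instead, and if neither dimension divides the device count the spec degenerates to no spatial sharding at all.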
