Skip to content

Commit 919bba3

Browse files
committed
trying way 2
1 parent 0400c2c commit 919bba3

1 file changed

Lines changed: 14 additions & 21 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -119,27 +119,6 @@ def __init__(
119119
)
120120

121121
def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array:
122-
# OPTIMIZATION: Spatial Partitioning during execution
123-
if self.mesh is not None and "context" in self.mesh.axis_names:
124-
height = x.shape[2]
125-
width = x.shape[3]
126-
num_context_devices = self.mesh.shape["context"]
127-
128-
shard_axis = "context" if (height % num_context_devices == 0) else None
129-
shard_width_axis = None
130-
if shard_axis is None and width % num_context_devices == 0:
131-
shard_width_axis = "context"
132-
133-
x = jax.lax.with_sharding_constraint(
134-
x, jax.sharding.PartitionSpec("data", None, shard_axis, shard_width_axis, None)
135-
)
136-
137-
# Debug logging
138-
if shard_axis or shard_width_axis:
139-
jax.debug.print(
140-
"Spatial sharding applied: height_axis={}, width_axis={} for shape {}",
141-
shard_axis, shard_width_axis, x.shape
142-
)
143122

144123
current_padding = list(self._causal_padding) # Mutable copy
145124
padding_needed = self._depth_padding_before
@@ -165,6 +144,20 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
165144
x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0)
166145
else:
167146
x_padded = x
147+
148+
if self.mesh is not None and "context" in self.mesh.axis_names:
149+
height = x_padded.shape[2]
150+
width = x_padded.shape[3]
151+
num_context_devices = self.mesh.shape["context"]
152+
153+
shard_axis = "context" if (height % num_context_devices == 0) else None
154+
shard_width_axis = None
155+
if shard_axis is None and width % num_context_devices == 0:
156+
shard_width_axis = "context"
157+
158+
x_padded = jax.lax.with_sharding_constraint(
159+
x_padded, jax.sharding.PartitionSpec("data", None, shard_axis, shard_width_axis, None)
160+
)
168161

169162
out = self.conv(x_padded)
170163
return out

0 commit comments

Comments
 (0)