
Commit 18d167c

nnx.jit with python for loop
1 parent 705b813 commit 18d167c

1 file changed: src/maxdiffusion/models/wan/autoencoder_kl_wan.py (23 additions, 25 deletions)
@@ -146,13 +146,19 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
     else:
       x_padded = x

-    if self.mesh is not None:
-      # Shard height dimension (index 2) along 'context' axis
-      # Shape is (Batch, Time, Height, Width, Channels)
-      # We only shard if the dimension is divisible by the mesh size to avoid XLA errors
-      if x_padded.shape[2] % self.mesh.shape["context"] == 0:
-        sharding = NamedSharding(self.mesh, P(None, None, "context", None, None))
-        x_padded = jax.lax.with_sharding_constraint(x_padded, sharding)
+    if self.mesh is not None and "context" in self.mesh.axis_names:
+      height = x_padded.shape[2]
+      width = x_padded.shape[3]
+      num_context_devices = self.mesh.shape["context"]
+
+      shard_axis = "context" if (height % num_context_devices == 0) else None
+      shard_width_axis = None
+      if shard_axis is None and width % num_context_devices == 0:
+        shard_width_axis = "context"
+
+      x_padded = jax.lax.with_sharding_constraint(
+          x_padded, jax.sharding.PartitionSpec("data", None, shard_axis, shard_width_axis, None)
+      )

     out = self.conv(x_padded)
     return out
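
A minimal, self-contained sketch of the guarded sharding pattern in this hunk. The mesh layout below is hypothetical (a 1-way "data" axis times all local devices on "context"); the real mesh comes from the model's configuration. The diff passes a bare PartitionSpec, which relies on an ambient mesh, while the sketch binds the mesh explicitly through NamedSharding so it stands alone:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical mesh: 1-way "data", N-way "context" over the local devices.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("data", "context"))

def constrain(x_padded: jax.Array) -> jax.Array:
  # Shape is (Batch, Time, Height, Width, Channels).
  num_context = mesh.shape["context"]
  # Prefer sharding height; fall back to width; leave both replicated
  # if neither dimension divides evenly across the context devices.
  shard_axis = "context" if x_padded.shape[2] % num_context == 0 else None
  shard_width_axis = None
  if shard_axis is None and x_padded.shape[3] % num_context == 0:
    shard_width_axis = "context"
  spec = PartitionSpec("data", None, shard_axis, shard_width_axis, None)
  return jax.lax.with_sharding_constraint(x_padded, NamedSharding(mesh, spec))

Falling back to None rather than raising keeps the constraint a no-op on meshes that don't divide the spatial dims, which is what the divisibility guard in the original code was for.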
@@ -769,7 +775,6 @@ def __init__(
         precision=precision,
     )

-  @nnx.jit(static_argnames="feat_idx")
   def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
     if feat_cache is not None:
       idx = feat_idx
@@ -918,7 +923,6 @@ def __init__(
         precision=precision,
     )

-  @nnx.jit(static_argnames="feat_idx")
   def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
     if feat_cache is not None:
       idx = feat_idx
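
The commit title points at the pattern behind these two removals: instead of jitting every block's __call__ with feat_idx as a static argument (one retrace per index), the model is jitted once at the top level and a plain Python for loop unrolls under that single trace. A hedged sketch of the idea; Block, run, and the shapes are illustrative, not from the repo:

from flax import nnx
import jax.numpy as jnp

class Block(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 4, rngs=rngs)

  def __call__(self, x, feat_idx=0):  # no per-block @nnx.jit here
    return self.linear(x)

@nnx.jit  # single trace covers the whole loop
def run(block, frames):
  outs = []
  for i in range(frames.shape[0]):  # Python loop, unrolled at trace time
    outs.append(block(frames[i], feat_idx=i))
  return jnp.stack(outs)

block = Block(nnx.Rngs(0))
print(run(block, jnp.ones((3, 2, 4))).shape)  # (3, 2, 4)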
@@ -1113,8 +1117,8 @@ def __init__(
        precision=precision,
    )

+  @nnx.jit
   def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
-    feat_cache.init_cache()
     if x.shape[-1] != 3:
       # reshape channel last for JAX
       x = jnp.transpose(x, (0, 2, 3, 4, 1))
@@ -1136,29 +1140,27 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
       )
       out = jnp.concatenate([out, out_], axis=1)

-    # Update back to the wrapper object if needed, but for result we use local vars
-    feat_cache._enc_feat_map = enc_feat_map
-
     enc = self.quant_conv(out)
     mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
     enc = jnp.concatenate([mu, logvar], axis=-1)
-    feat_cache.init_cache()
     return enc

   def encode(
       self, x: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
   ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
     """Encode video into latent distribution."""
+    feat_cache.init_cache()
     h = self._encode(x, feat_cache)
+    feat_cache.init_cache()
     posterior = WanDiagonalGaussianDistribution(h)
     if not return_dict:
       return (posterior,)
     return FlaxAutoencoderKLOutput(latent_dist=posterior)

+  @nnx.jit
   def _decode(
-      self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
-  ) -> Union[FlaxDecoderOutput, jax.Array]:
-    feat_cache.init_cache()
+      self, z: jax.Array, feat_cache: AutoencoderKLWanCache
+  ) -> jax.Array:
     iter_ = z.shape[1]
     x = self.post_quant_conv(z)

@@ -1188,14 +1190,8 @@ def _decode(
       fm4 = jnp.expand_dims(fm4, axis=axis)
       out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)

-    feat_cache._feat_map = dec_feat_map
-
     out = jnp.clip(out, min=-1.0, max=1.0)
-    feat_cache.init_cache()
-    if not return_dict:
-      return (out,)
-
-    return FlaxDecoderOutput(sample=out)
+    return out

   def decode(
       self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
@@ -1204,7 +1200,9 @@ def decode(
     # reshape channel last for JAX
     z = jnp.transpose(z, (0, 2, 3, 4, 1))
     assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {z.shape}"
-    decoded = self._decode(z, feat_cache).sample
+    feat_cache.init_cache()
+    decoded = self._decode(z, feat_cache)
+    feat_cache.init_cache()
     if not return_dict:
       return (decoded,)
     return FlaxDecoderOutput(sample=decoded)
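
The same cache-hygiene pattern appears on both paths: init_cache() now runs eagerly in the encode/decode wrappers, before and after the @nnx.jit-compiled inner call, rather than inside the traced body. A minimal sketch of that wrapper shape, with a stand-in Cache class instead of AutoencoderKLWanCache and a toy model in place of the VAE:

from flax import nnx
import jax.numpy as jnp

class Cache:
  """Stand-in for AutoencoderKLWanCache: a mutable Python-side feature map."""
  def __init__(self):
    self.feat_map = []

  def init_cache(self):
    self.feat_map = []

class TinyModel(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.lin = nnx.Linear(4, 4, rngs=rngs)

  def __call__(self, x):
    return self.lin(x)

@nnx.jit
def _encode(model, x):
  # Pure compute only; no cache resets inside the traced function.
  return model(x)

def encode(model, x, cache):
  cache.init_cache()   # eager reset before the jitted call
  h = _encode(model, x)
  cache.init_cache()   # eager reset after, so stale entries never leak
  return h

model = TinyModel(nnx.Rngs(0))
print(encode(model, jnp.ones((2, 4)), Cache()).shape)  # (2, 4)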
