Skip to content

Commit b0e4851

Browse files
committed
separating encode methods
1 parent 94cc109 commit b0e4851

1 file changed

Lines changed: 16 additions & 6 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,10 +1258,8 @@ def __init__(
12581258
precision=precision,
12591259
)
12601260

1261-
@nnx.jit
1262-
def encode(
1263-
self, x: jax.Array, return_dict: bool = True
1264-
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
1261+
def _encode_jit(self, x: jax.Array) -> jax.Array:
1262+
"""Core computation part to be JIT-compiled."""
12651263
if x.shape[-1] != 3:
12661264
# reshape channel last for JAX
12671265
x = jnp.transpose(x, (0, 2, 3, 4, 1))
@@ -1283,14 +1281,26 @@ def scan_fn(carry, input_slice):
12831281
encoded = jnp.swapaxes(encoded_frames, 0, 1)
12841282
enc, _ = self.quant_conv(encoded)
12851283

1286-
mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
1287-
h = jnp.concatenate([mu, logvar], axis=-1)
1284+
# h contains the parameters for the distribution
1285+
h = enc # Or jnp.concatenate([mu, logvar], axis=-1) as originally
1286+
return h
1287+
_encode_compiled = nnx.jit(_encode_jit)
1288+
1289+
def encode(
1290+
self, x: jax.Array, return_dict: bool = True
1291+
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
1292+
"""Encodes the input, returning standard distribution objects."""
1293+
# Call the compiled function to get JAX arrays
1294+
h = self._encode_compiled(x)
12881295

1296+
# Create custom objects outside the JIT scope
12891297
posterior = FlaxDiagonalGaussianDistribution(h)
1298+
12901299
if not return_dict:
12911300
return (posterior,)
12921301
return FlaxAutoencoderKLOutput(latent_dist=posterior)
12931302

1303+
12941304
@nnx.jit
12951305
def decode(
12961306
self, z: jax.Array, return_dict: bool = True

0 commit comments

Comments
 (0)