@@ -1258,10 +1258,8 @@ def __init__(
12581258 precision = precision ,
12591259 )
12601260
1261- @nnx .jit
1262- def encode (
1263- self , x : jax .Array , return_dict : bool = True
1264- ) -> Union [FlaxAutoencoderKLOutput , Tuple [FlaxDiagonalGaussianDistribution ]]:
1261+ def _encode_jit (self , x : jax .Array ) -> jax .Array :
1262+ """Core computation part to be JIT-compiled."""
12651263 if x .shape [- 1 ] != 3 :
12661264 # reshape channel last for JAX
12671265 x = jnp .transpose (x , (0 , 2 , 3 , 4 , 1 ))
@@ -1283,14 +1281,26 @@ def scan_fn(carry, input_slice):
12831281 encoded = jnp .swapaxes (encoded_frames , 0 , 1 )
12841282 enc , _ = self .quant_conv (encoded )
12851283
1286- mu , logvar = enc [:, :, :, :, : self .z_dim ], enc [:, :, :, :, self .z_dim :]
1287- h = jnp .concatenate ([mu , logvar ], axis = - 1 )
1284+ # h contains the parameters for the distribution
1285+ h = enc # Or jnp.concatenate([mu, logvar], axis=-1) as originally
1286+ return h
1287+ _encode_compiled = nnx .jit (_encode_jit )
1288+
1289+ def encode (
1290+ self , x : jax .Array , return_dict : bool = True
1291+ ) -> Union [FlaxAutoencoderKLOutput , Tuple [FlaxDiagonalGaussianDistribution ]]:
1292+ """Encodes the input, returning standard distribution objects."""
1293+ # Call the compiled function to get JAX arrays
1294+ h = self ._encode_compiled (x )
12881295
1296+ # Create custom objects outside the JIT scope
12891297 posterior = FlaxDiagonalGaussianDistribution (h )
1298+
12901299 if not return_dict :
12911300 return (posterior ,)
12921301 return FlaxAutoencoderKLOutput (latent_dist = posterior )
12931302
1303+
12941304 @nnx .jit
12951305 def decode (
12961306 self , z : jax .Array , return_dict : bool = True
0 commit comments