Skip to content

Commit f49f5d9

Browse files
committed
added nnx.jit decorator to encode and decode
1 parent 0dd0bf0 commit f49f5d9

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,7 @@ def __init__(
11931193
precision=precision,
11941194
)
11951195

1196+
@nnx.jit
11961197
def encode(
11971198
self, x: jax.Array, return_dict: bool = True
11981199
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
@@ -1225,6 +1226,7 @@ def scan_fn(carry, input_slice):
12251226
return (posterior,)
12261227
return FlaxAutoencoderKLOutput(latent_dist=posterior)
12271228

1229+
@nnx.jit
12281230
def decode(
12291231
self, z: jax.Array, return_dict: bool = True
12301232
) -> Union[FlaxDecoderOutput, jax.Array]:

0 commit comments

Comments
 (0)