Skip to content

Commit dd27d04

Browse files
committed
Refactor
1 parent 716bf6c commit dd27d04

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,7 @@ def __init__(
756756
precision=precision,
757757
)
758758

759+
@nnx.jit
759760
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
760761
if feat_cache is not None:
761762
idx = feat_idx
@@ -904,6 +905,7 @@ def __init__(
904905
precision=precision,
905906
)
906907

908+
@nnx.jit
907909
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
908910
if feat_cache is not None:
909911
idx = feat_idx
@@ -1151,7 +1153,6 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11511153
feat_cache.init_cache()
11521154
return enc
11531155

1154-
@nnx.jit
11551156
def encode(
11561157
self, x: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
11571158
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
@@ -1204,7 +1205,6 @@ def _decode(
12041205

12051206
return FlaxDecoderOutput(sample=out)
12061207

1207-
@nnx.jit
12081208
def decode(
12091209
self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True
12101210
) -> Union[FlaxDecoderOutput, jax.Array]:

0 commit comments

Comments
 (0)