Skip to content

Commit 983841d

Browse files
committed
full refactor
1 parent 876456b commit 983841d

1 file changed

Lines changed: 30 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,24 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
11181118
return x, new_cache
11191119

11201120

1121+
class AutoencoderKLWanCache:
    """Compatibility shim for WAN VAE feature caching.

    The jax.lax.scan-based refactor manages temporal cache internally per
    encode/decode call, so no persistent state lives here. This class is
    retained only so existing pipelines that construct and pass a cache
    object keep working unchanged.
    """

    def __init__(self, module):
        # Keep a reference to the owning module for API parity; nothing
        # else is stored since jax.lax.scan carries the cache itself.
        self.module = module

    def clear_cache(self):
        """No-op for API compatibility. Cache is created fresh for each encode/decode call."""
        pass
11211139
class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin):
11221140
def __init__(
11231141
self,
@@ -1192,11 +1210,16 @@ def __init__(
11921210
)
11931211

11941212
def encode(
1195-
self, x: jax.Array, return_dict: bool = True
1213+
self, x: jax.Array, feat_cache: Optional[AutoencoderKLWanCache] = None, return_dict: bool = True
11961214
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
11971215
"""
11981216
Encode video using jax.lax.scan for temporal iteration.
11991217
This enables proper JIT compilation while managing memory efficiently.
1218+
1219+
Args:
1220+
x: Input video tensor
1221+
feat_cache: Cache object (for API compatibility, not used internally)
1222+
return_dict: Whether to return FlaxAutoencoderKLOutput or tuple
12001223
"""
12011224
if x.shape[-1] != 3:
12021225
x = jnp.transpose(x, (0, 2, 3, 4, 1))
@@ -1230,11 +1253,16 @@ def scan_fn(carry, input_slice):
12301253
return FlaxAutoencoderKLOutput(latent_dist=posterior)
12311254

12321255
def decode(
1233-
self, z: jax.Array, return_dict: bool = True
1256+
self, z: jax.Array, feat_cache: Optional[AutoencoderKLWanCache] = None, return_dict: bool = True
12341257
) -> Union[FlaxDecoderOutput, jax.Array]:
12351258
"""
12361259
Decode latents using jax.lax.scan for temporal iteration.
12371260
This enables proper JIT compilation while managing memory efficiently.
1261+
1262+
Args:
1263+
z: Latent tensor
1264+
feat_cache: Cache object (for API compatibility, not used internally)
1265+
return_dict: Whether to return FlaxDecoderOutput or tuple
12381266
"""
12391267
if z.shape[-1] != self.z_dim:
12401268
z = jnp.transpose(z, (0, 2, 3, 4, 1))

0 commit comments

Comments
 (0)