@@ -1118,6 +1118,24 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
11181118 return x , new_cache
11191119
11201120
class AutoencoderKLWanCache:
    """
    Cache management for WAN VAE.

    Note: With jax.lax.scan architecture, cache is managed internally.
    This class exists for API compatibility with existing pipelines but doesn't
    actually store persistent cache anymore.
    """

    def __init__(self, module):
        # Keep a reference to the owning module purely for API compatibility;
        # no persistent feature cache is needed with the jax.lax.scan design.
        self.module = module

    def clear_cache(self):
        """No-op for API compatibility. Cache is created fresh for each encode/decode call."""
        return None
1137+
1138+
11211139class AutoencoderKLWan (nnx .Module , FlaxModelMixin , ConfigMixin ):
11221140 def __init__ (
11231141 self ,
@@ -1192,11 +1210,16 @@ def __init__(
11921210 )
11931211
11941212 def encode (
1195- self , x : jax .Array , return_dict : bool = True
1213+ self , x : jax .Array , feat_cache : Optional [ AutoencoderKLWanCache ] = None , return_dict : bool = True
11961214 ) -> Union [FlaxAutoencoderKLOutput , Tuple [FlaxDiagonalGaussianDistribution ]]:
11971215 """
11981216 Encode video using jax.lax.scan for temporal iteration.
11991217 This enables proper JIT compilation while managing memory efficiently.
1218+
1219+ Args:
1220+ x: Input video tensor
1221+ feat_cache: Cache object (for API compatibility, not used internally)
1222+ return_dict: Whether to return FlaxAutoencoderKLOutput or tuple
12001223 """
12011224 if x .shape [- 1 ] != 3 :
12021225 x = jnp .transpose (x , (0 , 2 , 3 , 4 , 1 ))
@@ -1230,11 +1253,16 @@ def scan_fn(carry, input_slice):
12301253 return FlaxAutoencoderKLOutput (latent_dist = posterior )
12311254
12321255 def decode (
1233- self , z : jax .Array , return_dict : bool = True
1256+ self , z : jax .Array , feat_cache : Optional [ AutoencoderKLWanCache ] = None , return_dict : bool = True
12341257 ) -> Union [FlaxDecoderOutput , jax .Array ]:
12351258 """
12361259 Decode latents using jax.lax.scan for temporal iteration.
12371260 This enables proper JIT compilation while managing memory efficiently.
1261+
1262+ Args:
1263+ z: Latent tensor
1264+ feat_cache: Cache object (for API compatibility, not used internally)
1265+ return_dict: Whether to return FlaxDecoderOutput or tuple
12381266 """
12391267 if z .shape [- 1 ] != self .z_dim :
12401268 z = jnp .transpose (z , (0 , 2 , 3 , 4 , 1 ))
0 commit comments