@@ -1258,34 +1258,31 @@ def __init__(
             precision=precision,
         )

-    # REMOVE @nnx.jit for now to ensure this logic runs
+    @nnx.jit
     def encode(
         self, x: jax.Array, return_dict: bool = True
     ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
         if x.shape[-1] != 3:
+            # reshape channel last for JAX
             x = jnp.transpose(x, (0, 2, 3, 4, 1))
         assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"

+        x_scan = jnp.swapaxes(x, 0, 1)
         b, t, h, w, c = x.shape
         init_cache = self.encoder.init_cache(b, h, w, x.dtype)

-        # Process first frame
-        out1, _ = self.encoder(x[:, :1, ...], init_cache)
-
-        if t > 1:
-            # Process remaining frames in one chunk
-            # We need to manage cache updates manually if not using scan
-            # This part is tricky because the new encoder returns cache,
-            # but the old logic didn't seem to carry cache between chunks.
+        def scan_fn(carry, input_slice):
+            # Expand Time dimension for Conv3d
+            input_slice = jnp.expand_dims(input_slice, 1)
+            out_slice, new_carry = self.encoder(input_slice, carry)
+            # Squeeze Time dimension for scan stacking
+            out_slice = jnp.squeeze(out_slice, 1)
+            return new_carry, out_slice

-            # Let's SIMPLIFY to match the OLD logic's spirit: Reset cache for the chunk
-            init_cache_rest = self.encoder.init_cache(b, h, w, x.dtype)
-            out_rest, _ = self.encoder(x[:, 1:, ...], init_cache_rest)
-            encoded = jnp.concatenate([out1, out_rest], axis=1)
-        else:
-            encoded = out1
+        final_cache, encoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
+        encoded = jnp.swapaxes(encoded_frames, 0, 1)
+        enc, _ = self.quant_conv(encoded)

-        enc, _ = self.quant_conv(encoded, cache_x=None)
         mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
         h = jnp.concatenate([mu, logvar], axis=-1)

@@ -1294,7 +1291,7 @@ def encode(
             return (posterior,)
         return FlaxAutoencoderKLOutput(latent_dist=posterior)

-    # @nnx.jit
+    @nnx.jit
     def decode(
         self, z: jax.Array, return_dict: bool = True
     ) -> Union[FlaxDecoderOutput, jax.Array]:
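
Side note: the new `encode` streams frames through the encoder with `jax.lax.scan`, carrying the convolution cache from each frame into the next (the removed chunked version reset the cache between chunks instead). Below is a minimal, self-contained sketch of that pattern under stated assumptions: `toy_encoder`, its cache shape, and `encode_video` are hypothetical placeholders, not the real module API. Note the `swapaxes` on both ends, since `scan` iterates over the leading axis, time has to be moved in front of batch and back again.

import jax
import jax.numpy as jnp

def toy_encoder(frame, cache):
    # Hypothetical stand-in for self.encoder: takes one frame of shape
    # (B, 1, H, W, C) plus a cache, returns the encoded frame and a new cache.
    out = frame + cache[:, None]                      # (B, 1, H, W, C)
    new_cache = 0.5 * cache + 0.5 * jnp.squeeze(frame, 1)
    return out, new_cache

def encode_video(x):
    # x: (B, T, H, W, C); scan iterates over the leading axis, so go time-major
    x_scan = jnp.swapaxes(x, 0, 1)                    # (T, B, H, W, C)
    init_cache = jnp.zeros((x.shape[0],) + x.shape[2:], x.dtype)

    def scan_fn(carry, frame):
        frame = jnp.expand_dims(frame, 1)             # restore the time dim for 3D convs
        out, new_carry = toy_encoder(frame, carry)
        return new_carry, jnp.squeeze(out, 1)         # drop it so scan can stack outputs

    final_cache, ys = jax.lax.scan(scan_fn, init_cache, x_scan)
    return jnp.swapaxes(ys, 0, 1)                     # back to (B, T, H, W, C)

x = jnp.ones((2, 5, 8, 8, 3))
print(encode_video(x).shape)                          # (2, 5, 8, 8, 3)

One caveat worth flagging: `jax.lax.scan` only threads state that passes through the carry. That is fine here because the cache is explicit, but if `self.encoder` also mutated internal nnx state on each call, the loop would likely need `nnx.scan` (or prior functionalization) rather than a plain closure over the module.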