@@ -1258,49 +1258,82 @@ def __init__(
             precision=precision,
         )
 
-    def _encode_jit(self, x: jax.Array) -> jax.Array:
-        """Core computation part to be JIT-compiled."""
+    @nnx.jit
+    def encode(
+        self, x: jax.Array, return_dict: bool = True
+    ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
         if x.shape[-1] != 3:
-            # reshape channel last for JAX
             x = jnp.transpose(x, (0, 2, 3, 4, 1))
-        assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
+        assert x.shape[-1] == 3, "Input channels must be 3"
 
-        x_scan = jnp.swapaxes(x, 0, 1)
         b, t, h, w, c = x.shape
-        init_cache = self.encoder.init_cache(b, h, w, x.dtype)
+        chunk_size = 4  # Process in chunks of 4 frames
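+        # A 4-frame chunk maps to a single latent frame, since the encoder
+        # downsamples time by 4x (see the out_chunk shape note in scan_fn below).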
 
-        def scan_fn(carry, input_slice):
-            # Expand Time dimension for Conv3d
-            input_slice = jnp.expand_dims(input_slice, 1)
-            out_slice, new_carry = self.encoder(input_slice, carry)
-            # Squeeze Time dimension for scan stacking
-            out_slice = jnp.squeeze(out_slice, 1)
-            return new_carry, out_slice
+        # Calculate the padding needed to make the time dimension a multiple of chunk_size
+        num_chunks = math.ceil(t / chunk_size)
+        padded_t = num_chunks * chunk_size
+        padding_t = padded_t - t
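+        # e.g. t = 9, chunk_size = 4 -> num_chunks = 3, padded_t = 12, padding_t = 3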
 
-        final_cache, encoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
-        encoded = jnp.swapaxes(encoded_frames, 0, 1)
-        enc, _ = self.quant_conv(encoded)
+        if padding_t > 0:
+            paddings = [(0, 0)] * x.ndim
+            paddings[1] = (0, padding_t)  # Pad at the end of the time dimension
+            x_padded = jnp.pad(x, paddings, mode='constant', constant_values=0.0)
+        else:
+            x_padded = x
 
-        # h contains the parameters for the distribution
-        h = enc  # Or jnp.concatenate([mu, logvar], axis=-1) as originally
-        return h
-    _encode_compiled = nnx.jit(_encode_jit)
+        # Reshape for scan: (B, Num_Chunks, Chunk_T, H, W, C)
+        x_reshaped = x_padded.reshape((b, num_chunks, chunk_size, h, w, c))
 
-    def encode(
-        self, x: jax.Array, return_dict: bool = True
-    ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
-        """Encodes the input, returning standard distribution objects."""
-        # Call the compiled function to get JAX arrays
-        h = self._encode_compiled(x)
+        # Swap axes for scan: (Num_Chunks, B, Chunk_T, H, W, C)
+        x_scannable = jnp.swapaxes(x_reshaped, 0, 1)
+
+        # Define the function executed at each step of the scan
+        def scan_fn(dummy_carry, x_chunk):
+            # x_chunk shape: (B, chunk_size, H, W, C)
+            b_c, _, h_c, w_c, _ = x_chunk.shape
+
+            # Reset the cache for each chunk so chunks are encoded independently
+            init_cache = self.encoder.init_cache(b_c, h_c, w_c, x_chunk.dtype)
+
+            # Use gradient checkpointing to save memory
+            out_chunk, _ = nnx.checkpoint(self.encoder)(x_chunk, init_cache)
+            # Expected out_chunk shape: (B, 1, H', W', Z*2),
+            # as each 4-frame chunk is downsampled temporally by 4x.
+
+            return dummy_carry, out_chunk
+
+        # Initial carry for the scan - not used to propagate state between chunks
+        initial_scan_carry = self.encoder.init_cache(b, h, w, x.dtype)
+
+        # Run the scan over the chunks
+        _, encoded_chunks = jax.lax.scan(scan_fn, initial_scan_carry, x_scannable)
+        # encoded_chunks shape: (num_chunks, B, 1, H', W', Z*2)
+
+        # Stitch the per-chunk results back together:
+        # transpose to (B, num_chunks, 1, H', W', Z*2)
+        encoded_combined = jnp.swapaxes(encoded_chunks, 0, 1)
+
+        # Reshape to (B, num_chunks, H', W', Z*2)
+        b_out, nc_out, t_out_chunk, h_out, w_out, c_out = encoded_combined.shape
+        encoded = encoded_combined.reshape((b_out, nc_out * t_out_chunk, h_out, w_out, c_out))
+        # Final 'encoded' shape: (B, num_chunks, H', W', Z*2);
+        # for T=9, num_chunks=3, which matches the expected (B, 3, H', W', Z*2).
+
+        # Post-process to get the distribution parameters
+        enc, _ = self.quant_conv(encoded, cache_x=None)
+        mu = enc[..., : self.z_dim]
+        logvar = enc[..., self.z_dim :]
+        h = jnp.concatenate([mu, logvar], axis=-1)
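+        # FlaxDiagonalGaussianDistribution consumes [mu, logvar] packed along the last axis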
 
-        # Create custom objects outside the JIT scope
         posterior = FlaxDiagonalGaussianDistribution(h)
 
         if not return_dict:
             return (posterior,)
         return FlaxAutoencoderKLOutput(latent_dist=posterior)
 
-
     @nnx.jit
     def decode(
         self, z: jax.Array, return_dict: bool = True
@@ -1315,22 +1348,15 @@ def decode(
         init_cache = self.decoder.init_cache(b, h, w, x.dtype)
 
         def scan_fn(carry, input_slice):
-            # Expand Time dimension for Conv3d
             input_slice = jnp.expand_dims(input_slice, 1)
-            # OPTIMIZATION: Force bfloat16 accumulation within the scan
-            # to save memory on the massive output buffer
             out_slice, new_carry = self.decoder(input_slice, carry)
-            # out_slice = out_slice.astype(jnp.bfloat16)
             return new_carry, out_slice
 
         final_cache, decoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
 
-        # decoded_frames shape: (T_lat, B, 4, H, W, C)
-        # We need to flatten T_lat and 4.
-        # Transpose to (B, T_lat, 4, H, W, C)
         decoded = jnp.transpose(decoded_frames, (1, 0, 2, 3, 4, 5))
 
-        # Reshape to (B, T_lat*4, H, W, C)
         b, t_lat, t_sub, h, w, c = decoded.shape
         decoded = decoded.reshape(b, t_lat * t_sub, h, w, c)
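+        # e.g. t_lat = 3 latent frames, t_sub = 4 -> 12 decoded frames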
 
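For reference, the chunk-and-scan pattern that the new `encode` implements distills into the standalone sketch below. This is a minimal illustration under stated assumptions, not the committed code: `encoder_apply` is a hypothetical stand-in for `self.encoder` plus its fresh per-chunk cache, and the shapes follow the comments in the diff.

    import math
    import jax
    import jax.numpy as jnp

    def encode_in_chunks(x, encoder_apply, chunk_size=4):
        # x: (B, T, H, W, C); encoder_apply maps (B, chunk_size, H, W, C)
        # to (B, 1, H', W', Z*2), i.e. a 4x temporal downsample.
        b, t, h, w, c = x.shape
        num_chunks = math.ceil(t / chunk_size)
        pad_t = num_chunks * chunk_size - t
        x = jnp.pad(x, [(0, 0), (0, pad_t), (0, 0), (0, 0), (0, 0)])
        # (Num_Chunks, B, Chunk_T, H, W, C): scan iterates over the leading axis
        x = jnp.swapaxes(x.reshape(b, num_chunks, chunk_size, h, w, c), 0, 1)

        def step(carry, chunk):
            # The carry is unused; every chunk is encoded independently.
            return carry, encoder_apply(chunk)

        _, out = jax.lax.scan(step, None, x)  # (num_chunks, B, 1, H', W', Z*2)
        out = jnp.swapaxes(out, 0, 1)         # (B, num_chunks, 1, H', W', Z*2)
        return out.reshape(b, -1, *out.shape[3:])

    # Toy check with an encoder stub that averages each chunk down to one frame:
    z = encode_in_chunks(jnp.ones((2, 9, 8, 8, 3)),
                         lambda chunk: chunk.mean(axis=1, keepdims=True))
    assert z.shape == (2, 3, 8, 8, 3)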