@@ -1259,7 +1259,7 @@ def __init__(
12591259 precision = precision ,
12601260 )
12611261
1262- @nnx .jit
1262+ @nnx .jit # JIT the whole encode method
12631263 def encode (
12641264 self , x : jax .Array , return_dict : bool = True
12651265 ) -> Union [FlaxAutoencoderKLOutput , Tuple [FlaxDiagonalGaussianDistribution ]]:
@@ -1269,13 +1269,14 @@ def encode(
12691269
12701270 b , t , h , w , c = x .shape
12711271 chunk_size = 4 # Process in chunks of 4 frames
1272+ # Assuming the encoder downsamples time by a factor of 4 overall
12721273
1273- # Calculate padding needed to make the time dimension a multiple of chunk_size
12741274 num_chunks = math .ceil (t / chunk_size )
12751275 padded_t = num_chunks * chunk_size
12761276 padding_t = padded_t - t
12771277
12781278 if padding_t > 0 :
1279+ # Pad the time dimension to be a multiple of chunk_size
12791280 paddings = [(0 , 0 )] * x .ndim
12801281 paddings [1 ] = (0 , padding_t ) # Pad at the end of the time dimension
12811282 x_padded = jnp .pad (x , paddings , mode = 'constant' , constant_values = 0.0 )
@@ -1288,23 +1289,26 @@ def encode(
12881289 # Swap axes for scan: (Num_Chunks, B, Chunk_T, H, W, C)
12891290 x_scannable = jnp .swapaxes (x_reshaped , 0 , 1 )
12901291
1291- # Define the function to be executed in each step of the scan
1292+ # Wrap the encoder's call method with jax.checkpoint
1293+ # nnx.Module instances are callable, so this works.
1294+ encoder_checkpointed = jax .checkpoint (self .encoder )
1295+
12921296 def scan_fn (dummy_carry , x_chunk ):
12931297 # x_chunk shape: (B, chunk_size, H, W, C)
12941298 b_c , _ , h_c , w_c , _ = x_chunk .shape
12951299
1296- # Reset cache for each chunk to ensure independence
1300+ # Reset cache for each chunk to ensure independence as per original logic
12971301 init_cache = self .encoder .init_cache (b_c , h_c , w_c , x_chunk .dtype )
12981302
1299- # Use gradient checkpointing to save memory
1300- out_chunk , _ = nnx .checkpoint (self .encoder )(x_chunk , init_cache )
1301- # Expected out_chunk shape: (B, 1, H', W', Z*2)
1302- # as each 4-frame chunk is downsampled temporally by 4x.
1303+ # Call the checkpointed encoder
1304+ out_chunk , _ = encoder_checkpointed (x_chunk , init_cache )
1305+ # Expected out_chunk shape: (B, 1, H', W', Z*2), assuming 4x temporal downsampling per chunk
13031306
13041307 return dummy_carry , out_chunk
13051308
1306- # Initial carry for scan - not used for state propagation between chunks
1307- initial_scan_carry = self .encoder .init_cache (b , h , w , x .dtype )
1309+ # The initial carry structure for scan needs to match the output carry structure of scan_fn.
1310+ # Since we don't propagate the cache *between* chunks, dummy_carry can be simple.
1311+ initial_scan_carry = {}
13081312
13091313 # Run the scan over the chunks
13101314 _ , encoded_chunks = jax .lax .scan (scan_fn , initial_scan_carry , x_scannable )
@@ -1314,11 +1318,10 @@ def scan_fn(dummy_carry, x_chunk):
13141318 # Transpose back to (B, num_chunks, 1, H', W', Z*2)
13151319 encoded_combined = jnp .swapaxes (encoded_chunks , 0 , 1 )
13161320
1317- # Reshape to (B, num_chunks, H', W', Z*2)
1321+ # Reshape to (B, num_chunks * 1, H', W', Z*2) -> (B, num_chunks , H', W', Z*2)
13181322 b_out , nc_out , t_out_chunk , h_out , w_out , c_out = encoded_combined .shape
13191323 encoded = encoded_combined .reshape ((b_out , nc_out * t_out_chunk , h_out , w_out , c_out ))
1320- # Final 'encoded' shape: (B, num_chunks, H', W', Z*2)
1321- # For T=9, num_chunks=3. This matches the expected (B, 3, H', W', Z*2)
1324+ # Final 'encoded' shape: (B, 3, H', W', Z*2) for T=9 input
13221325
13231326 # Post-processing to get distribution parameters
13241327 enc , _ = self .quant_conv (encoded , cache_x = None )
@@ -1350,7 +1353,8 @@ def scan_fn(carry, input_slice):
13501353 out_slice , new_carry = self .decoder (input_slice , carry )
13511354 return new_carry , out_slice
13521355
1353- final_cache , decoded_frames = jax .lax .scan (scan_fn , initial_scan_carry , x_scan )
1356+ # Need to provide a valid initial cache structure for the scan
1357+ final_cache , decoded_frames = jax .lax .scan (scan_fn , init_cache , x_scan )
13541358
13551359 decoded = jnp .transpose (decoded_frames , (1 , 0 , 2 , 3 , 4 , 5 ))
13561360
0 commit comments