@@ -1259,17 +1259,14 @@ def __init__(
12591259 precision = precision ,
12601260 )
12611261
1262- @nnx .jit # JIT the whole encode method
1263- def encode (
1264- self , x : jax .Array , return_dict : bool = True
1265- ) -> Union [FlaxAutoencoderKLOutput , Tuple [FlaxDiagonalGaussianDistribution ]]:
1262+ def _encode_jit (self , x : jax .Array ) -> jax .Array :
1263+ """Contains the core JAX computations for encoding, suitable for JIT."""
12661264 if x .shape [- 1 ] != 3 :
12671265 x = jnp .transpose (x , (0 , 2 , 3 , 4 , 1 ))
1268- assert x .shape [- 1 ] == 3 , "Input channels must be 3"
1266+ # assert x.shape[-1] == 3, "Input channels must be 3" # Assertions might not be ideal in JIT
12691267
12701268 b , t , h , w , c = x .shape
12711269 chunk_size = 4 # Process in chunks of 4 frames
1272- # Assuming the encoder downsamples time by a factor of 4 overall
12731270
12741271 num_chunks = math .ceil (t / chunk_size )
12751272 padded_t = num_chunks * chunk_size
@@ -1289,47 +1286,45 @@ def encode(
12891286 # Swap axes for scan: (Num_Chunks, B, Chunk_T, H, W, C)
12901287 x_scannable = jnp .swapaxes (x_reshaped , 0 , 1 )
12911288
1292- # Wrap the encoder's call method with jax.checkpoint
1293- # nnx.Module instances are callable, so this works.
12941289 encoder_checkpointed = jax .checkpoint (self .encoder )
12951290
12961291 def scan_fn (dummy_carry , x_chunk ):
12971292 # x_chunk shape: (B, chunk_size, H, W, C)
12981293 b_c , _ , h_c , w_c , _ = x_chunk .shape
1299-
1300- # Reset cache for each chunk to ensure independence as per original logic
13011294 init_cache = self .encoder .init_cache (b_c , h_c , w_c , x_chunk .dtype )
1302-
1303- # Call the checkpointed encoder
13041295 out_chunk , _ = encoder_checkpointed (x_chunk , init_cache )
1305- # Expected out_chunk shape: (B, 1, H', W', Z*2), assuming 4x temporal downsampling per chunk
1306-
13071296 return dummy_carry , out_chunk
13081297
1309- # The initial carry structure for scan needs to match the output carry structure of scan_fn.
1310- # Since we don't propagate the cache *between* chunks, dummy_carry can be simple.
13111298 initial_scan_carry = {}
1312-
1313- # Run the scan over the chunks
13141299 _ , encoded_chunks = jax .lax .scan (scan_fn , initial_scan_carry , x_scannable )
1315- # encoded_chunks shape: (num_chunks, B, 1, H', W', Z*2)
13161300
1317- # Concatenate the results from each chunk
1318- # Transpose back to (B, num_chunks, 1, H', W', Z*2)
13191301 encoded_combined = jnp .swapaxes (encoded_chunks , 0 , 1 )
13201302
1321- # Reshape to (B, num_chunks * 1, H', W', Z*2) -> (B, num_chunks, H', W', Z*2)
13221303 b_out , nc_out , t_out_chunk , h_out , w_out , c_out = encoded_combined .shape
13231304 encoded = encoded_combined .reshape ((b_out , nc_out * t_out_chunk , h_out , w_out , c_out ))
1324- # Final 'encoded' shape: (B, 3, H', W', Z*2) for T=9 input
13251305
1326- # Post-processing to get distribution parameters
13271306 enc , _ = self .quant_conv (encoded , cache_x = None )
1328- mu = enc [..., :self .z_dim ]
1329- logvar = enc [..., self .z_dim :]
1330- h = jnp .concatenate ([mu , logvar ], axis = - 1 )
1307+ # mu = enc[..., :self.z_dim]
1308+ # logvar = enc[..., self.z_dim:]
1309+ # h = jnp.concatenate([mu, logvar], axis=-1)
1310+ return enc # Return the direct output of quant_conv
1311+
1312+ # JIT compile the internal JAX-based function
1313+ _encode_compiled = nnx .jit (_encode_jit )
1314+
1315+ def encode (
1316+ self , x : jax .Array , return_dict : bool = True
1317+ ) -> Union [FlaxAutoencoderKLOutput , Tuple [FlaxDiagonalGaussianDistribution ]]:
1318+ """Encodes the input array and returns custom distribution objects."""
1319+ if x .shape [- 1 ] != 3 :
1320+ # Transpose in the non-JIT part if needed, though _encode_jit handles it too
1321+ pass # Handled inside _encode_jit
1322+
1323+ # Call the JIT-compiled function to get the raw encoded array
1324+ h_params = self ._encode_compiled (x )
13311325
1332- posterior = FlaxDiagonalGaussianDistribution (h )
1326+ # Create the custom Python objects from the JAX array results
1327+ posterior = FlaxDiagonalGaussianDistribution (h_params )
13331328
13341329 if not return_dict :
13351330 return (posterior ,)
0 commit comments