@@ -1263,51 +1263,52 @@ def _encode_jit(self, x: jax.Array) -> jax.Array:
12631263 """Contains the core JAX computations for encoding, suitable for JIT."""
12641264 if x .shape [- 1 ] != 3 :
12651265 x = jnp .transpose (x , (0 , 2 , 3 , 4 , 1 ))
1266- # assert x.shape[-1] == 3, "Input channels must be 3" # Assertions might not be ideal in JIT
12671266
12681267 b , t , h , w , c = x .shape
1269- chunk_size = 4 # Process in chunks of 4 frames
1270-
1271- num_chunks = math .ceil (t / chunk_size )
1272- padded_t = num_chunks * chunk_size
1273- padding_t = padded_t - t
1274-
1275- if padding_t > 0 :
1276- # Pad the time dimension to be a multiple of chunk_size
1277- paddings = [(0 , 0 )] * x .ndim
1278- paddings [1 ] = (0 , padding_t ) # Pad at the end of the time dimension
1279- x_padded = jnp .pad (x , paddings , mode = 'constant' , constant_values = 0.0 )
1280- else :
1281- x_padded = x
1282-
1283- # Reshape for scan: (B, Num_Chunks, Chunk_T, H, W, C)
1284- x_reshaped = x_padded .reshape ((b , num_chunks , chunk_size , h , w , c ))
1285-
1286- # Swap axes for scan: (Num_Chunks, B, Chunk_T, H, W, C)
1287- x_scannable = jnp .swapaxes (x_reshaped , 0 , 1 )
1268+ all_outs = []
12881269
1270+ # Process the first frame (Time=1)
1271+ x_first = x [:, :1 , ...]
1272+ init_cache_first = self .encoder .init_cache (b , h , w , x_first .dtype )
12891273 encoder_checkpointed = jax .checkpoint (self .encoder )
1274+ out1 , state_carry = encoder_checkpointed (x_first , init_cache_first )
1275+ all_outs .append (out1 )
1276+
1277+ # Process the remaining frames using scan over chunks of 4
1278+ if t > 1 :
1279+ x_rest = x [:, 1 :, ...]
1280+ t_rest = x_rest .shape [1 ]
1281+ chunk_size = 4
1282+
1283+ num_chunks = math .ceil (t_rest / chunk_size )
1284+ padded_t_rest = num_chunks * chunk_size
1285+ padding_t = padded_t_rest - t_rest
1286+
1287+ if padding_t > 0 :
1288+ paddings = [(0 , 0 )] * x_rest .ndim
1289+ paddings [1 ] = (0 , padding_t ) # Pad at the end
1290+ x_rest_padded = jnp .pad (x_rest , paddings , mode = 'constant' , constant_values = 0.0 )
1291+ else :
1292+ x_rest_padded = x_rest
12901293
1291- def scan_fn (dummy_carry , x_chunk ):
1292- # x_chunk shape: (B, chunk_size, H, W, C)
1293- b_c , _ , h_c , w_c , _ = x_chunk .shape
1294- init_cache = self .encoder .init_cache (b_c , h_c , w_c , x_chunk .dtype )
1295- out_chunk , _ = encoder_checkpointed (x_chunk , init_cache )
1296- return dummy_carry , out_chunk
1294+ x_reshaped = x_rest_padded .reshape ((b , num_chunks , chunk_size , h , w , c ))
1295+ x_scannable = jnp .swapaxes (x_reshaped , 0 , 1 )
12971296
1298- initial_scan_carry = {}
1299- _ , encoded_chunks = jax .lax .scan (scan_fn , initial_scan_carry , x_scannable )
1297+ def scan_fn (carry_state , x_chunk ):
1298+ out_chunk , new_state = encoder_checkpointed (x_chunk , carry_state )
1299+ return new_state , out_chunk
13001300
1301- encoded_combined = jnp .swapaxes (encoded_chunks , 0 , 1 )
1301+ _ , encoded_chunks = jax .lax .scan (scan_fn , state_carry , x_scannable )
1302+ encoded_rest = jnp .swapaxes (encoded_chunks , 0 , 1 )
1303+ b_out , nc_out , t_out_chunk , h_out , w_out , c_out = encoded_rest .shape
1304+ encoded_rest = encoded_rest .reshape ((b_out , nc_out * t_out_chunk , h_out , w_out , c_out ))
13021305
1303- b_out , nc_out , t_out_chunk , h_out , w_out , c_out = encoded_combined .shape
1304- encoded = encoded_combined .reshape ((b_out , nc_out * t_out_chunk , h_out , w_out , c_out ))
1306+ all_outs .append (encoded_rest )
1307+
1308+ encoded = jnp .concatenate (all_outs , axis = 1 )
13051309
13061310 enc , _ = self .quant_conv (encoded , cache_x = None )
1307- # mu = enc[..., :self.z_dim]
1308- # logvar = enc[..., self.z_dim:]
1309- # h = jnp.concatenate([mu, logvar], axis=-1)
1310- return enc # Return the direct output of quant_conv
1311+ return enc
13111312
13121313 # JIT compile the internal JAX-based function
13131314 _encode_compiled = nnx .jit (_encode_jit )
@@ -1316,48 +1317,59 @@ def encode(
13161317 self , x : jax .Array , return_dict : bool = True
13171318 ) -> Union [FlaxAutoencoderKLOutput , Tuple [FlaxDiagonalGaussianDistribution ]]:
13181319 """Encodes the input array and returns custom distribution objects."""
1319- if x .shape [- 1 ] != 3 :
1320- # Transpose in the non-JIT part if needed, though _encode_jit handles it too
1321- pass # Handled inside _encode_jit
1322-
1323- # Call the JIT-compiled function to get the raw encoded array
13241320 h_params = self ._encode_compiled (x )
1325-
1326- # Create the custom Python objects from the JAX array results
13271321 posterior = FlaxDiagonalGaussianDistribution (h_params )
1328-
13291322 if not return_dict :
13301323 return (posterior ,)
13311324 return FlaxAutoencoderKLOutput (latent_dist = posterior )
13321325
1333- @nnx .jit
1334- def decode (
1335- self , z : jax .Array , return_dict : bool = True
1336- ) -> Union [FlaxDecoderOutput , jax .Array ]:
1326+ def _decode_jit (self , z : jax .Array ) -> jax .Array :
1327+ """Core JAX decoding logic with scan and frame swapping."""
13371328 if z .shape [- 1 ] != self .z_dim :
13381329 z = jnp .transpose (z , (0 , 2 , 3 , 4 , 1 ))
1339-
13401330 x , _ = self .post_quant_conv (z )
1341- x_scan = jnp .swapaxes (x , 0 , 1 )
1342-
1331+
13431332 b , t , h , w , c = x .shape
13441333 init_cache = self .decoder .init_cache (b , h , w , x .dtype )
1334+ decoder_checkpointed = jax .checkpoint (self .decoder )
1335+
1336+ all_decoded = []
1337+ x_first = x [:, :1 , ...]
1338+ out_first , state_carry = decoder_checkpointed (x_first , init_cache )
1339+ all_decoded .append (out_first )
1340+ if t > 1 :
1341+ x_rest = x [:, 1 :, ...]
1342+ x_scan = jnp .swapaxes (x_rest , 0 , 1 )
1343+
1344+ def scan_fn (carry , input_slice ):
1345+ input_slice = jnp .expand_dims (input_slice , 1 )
1346+ out_slice , new_carry = decoder_checkpointed (input_slice , carry )
1347+ out_swapped = out_slice [:, jnp .array ([0 , 2 , 1 , 3 ]), ...]
1348+
1349+ return new_carry , out_swapped
1350+
1351+ _ , decoded_rest = jax .lax .scan (scan_fn , state_carry , x_scan )
1352+
1353+ decoded_rest = jnp .swapaxes (decoded_rest , 0 , 1 )
1354+
1355+ b_r , t_r , sub_t , h_r , w_r , c_r = decoded_rest .shape
1356+ decoded_rest = decoded_rest .reshape (b_r , t_r * sub_t , h_r , w_r , c_r )
1357+
1358+ all_decoded .append (decoded_rest )
1359+
1360+ out = jnp .concatenate (all_decoded , axis = 1 )
1361+ out = jnp .clip (out , min = - 1.0 , max = 1.0 )
1362+
1363+ return out
1364+ _decode_compiled = nnx .jit (_decode_jit )
13451365
1346- def scan_fn (carry , input_slice ):
1347- input_slice = jnp .expand_dims (input_slice , 1 )
1348- out_slice , new_carry = self .decoder (input_slice , carry )
1349- return new_carry , out_slice
1350-
1351- # Need to provide a valid initial cache structure for the scan
1352- final_cache , decoded_frames = jax .lax .scan (scan_fn , init_cache , x_scan )
1353-
1354- decoded = jnp .transpose (decoded_frames , (1 , 0 , 2 , 3 , 4 , 5 ))
1355-
1356- b , t_lat , t_sub , h , w , c = decoded .shape
1357- decoded = decoded .reshape (b , t_lat * t_sub , h , w , c )
1358-
1359- out = jnp .clip (decoded , min = - 1.0 , max = 1.0 )
1366+ def decode (
1367+ self , z : jax .Array , return_dict : bool = True
1368+ ) -> Union [FlaxDecoderOutput , jax .Array ]:
1369+
1370+ decoded = self ._decode_compiled (z )
13601371
13611372 if not return_dict :
1362- return (out ,)
1363- return FlaxDecoderOutput (sample = out )
1373+ return (decoded ,)
1374+
1375+ return FlaxDecoderOutput (sample = decoded )
0 commit comments