@@ -1197,34 +1197,69 @@ def __init__(
def encode(
    self, x: jax.Array, return_dict: bool = True
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
    """Encode a video tensor into a latent diagonal-Gaussian posterior.

    The first frame is pushed through the encoder alone to prime the
    streaming cache; the remaining frames are consumed in chunks of 4 via
    ``jax.lax.scan``.  NOTE(review): the comment in the original states the
    encoder's temporal stride turns each 4-frame chunk into one latent
    frame — confirm against the encoder definition, which is not visible
    here.

    Args:
        x: Input video, channel-last ``(N, D, H, W, 3)``; a channel-first
            ``(N, 3, D, H, W)`` input is transposed internally.
        return_dict: If ``True`` (default), wrap the posterior in a
            ``FlaxAutoencoderKLOutput``; otherwise return a 1-tuple.

    Returns:
        ``FlaxAutoencoderKLOutput(latent_dist=posterior)`` or
        ``(posterior,)`` depending on ``return_dict``.

    Raises:
        ValueError: If the input is not RGB after the transpose, or if the
            frame count is not of the form ``1 + 4*k``.
    """
    # 1. Move channels last for JAX convolutions if input is channel-first.
    if x.shape[-1] != 3:
        x = jnp.transpose(x, (0, 2, 3, 4, 1))
    # Explicit validation instead of `assert`: asserts vanish under -O.
    if x.shape[-1] != 3:
        raise ValueError(f"Expected input shape (N, D, H, W, 3), got {x.shape}")

    b, t, h, w, c = x.shape

    # 2. First frame is processed individually to prime the cache
    #    (the 'i == 0' iteration of the original frame-by-frame loop).
    x_first = x[:, :1, ...]  # (B, 1, H, W, C)

    # 3. Remaining frames (index 1..end) are scanned in chunks of 4
    #    (the 'i > 0' iterations of the original loop).
    t_rest = t - 1
    if t_rest % 4 != 0:
        raise ValueError(
            f"Remaining frames {t_rest} must be divisible by 4 (Total frames must be 1 + 4*k)"
        )
    num_chunks = t_rest // 4

    # (B, T-1, H, W, C) -> (num_chunks, B, 4, H, W, C); scan iterates axis 0.
    x_chunks = x[:, 1:, ...].reshape(b, num_chunks, 4, h, w, c)
    x_chunks = jnp.transpose(x_chunks, (1, 0, 2, 3, 4, 5))

    # 4. Fresh streaming cache for this batch/resolution/dtype.
    init_cache = self.encoder.init_cache(b, h, w, x.dtype)

    # 5. Run the first frame; its output cache seeds the scan carry.
    enc_first, cache_after_first = self.encoder(x_first, init_cache)

    # 6. Scan the remaining 4-frame chunks, threading the cache through.
    def scan_fn(carry, input_chunk):
        # input_chunk: (B, 4, H, W, C)
        out_chunk, new_carry = self.encoder(input_chunk, carry)
        return new_carry, out_chunk

    # Final cache is not needed after encoding; discard it.
    _, enc_rest_chunks = jax.lax.scan(scan_fn, cache_after_first, x_chunks)

    # 7. Reassemble: (num_chunks, B, t_chunk, h', w', c')
    #    -> (B, num_chunks * t_chunk, h', w', c'), then prepend the
    #    first-frame latent along the time axis.
    enc_rest_chunks = jnp.swapaxes(enc_rest_chunks, 0, 1)
    b_out, n_chunks, t_chunk_out, h_out, w_out, c_out = enc_rest_chunks.shape
    enc_rest = enc_rest_chunks.reshape(
        b_out, n_chunks * t_chunk_out, h_out, w_out, c_out
    )
    encoded = jnp.concatenate([enc_first, enc_rest], axis=1)

    # 8. Project to distribution moments and build the posterior.
    enc, _ = self.quant_conv(encoded)
    # NOTE: the split/concat below reproduces `enc` unchanged; it is kept
    # to document the (mu, logvar) channel layout expected downstream.
    mu, logvar = enc[..., : self.z_dim], enc[..., self.z_dim :]
    h_latents = jnp.concatenate([mu, logvar], axis=-1)

    posterior = FlaxDiagonalGaussianDistribution(h_latents)
    if not return_dict:
        return (posterior,)
    return FlaxAutoencoderKLOutput(latent_dist=posterior)
12281263
12291264 @nnx .jit
12301265 def decode (
0 commit comments