@@ -1202,21 +1202,55 @@ def encode(
12021202 x = jnp .transpose (x , (0 , 2 , 3 , 4 , 1 ))
12031203 assert x .shape [- 1 ] == 3 , f"Expected input shape (N, D, H, W, 3), got { x .shape } "
12041204
1205- x_scan = jnp .swapaxes (x , 0 , 1 )
12061205 b , t , h , w , c = x .shape
1207- init_cache = self . encoder . init_cache ( b , h , w , x . dtype )
1206+ all_outs = []
12081207
1209- def scan_fn (carry , input_slice ):
1210- # Expand Time dimension for Conv3d
1211- input_slice = jnp .expand_dims (input_slice , 1 )
1208+ def scan_fn_chunk (carry , input_slice ):
1209+ # input_slice shape is (B, H, W, C)
1210+ input_slice = jnp .expand_dims (input_slice , 1 ) # Shape (B, 1, H, W, C) for encoder
12121211 out_slice , new_carry = self .encoder (input_slice , carry )
1213- # Squeeze Time dimension for scan stacking
1214- out_slice = jnp .squeeze (out_slice , 1 )
1212+ # out_slice shape is (B, 1, H', W', C')
1213+ out_slice = jnp .squeeze (out_slice , 1 ) # Shape (B, H', W', C') for scan output
12151214 return new_carry , out_slice
12161215
1217- final_cache , encoded_frames = jax .lax .scan (scan_fn , init_cache , x_scan )
1218- encoded = jnp .swapaxes (encoded_frames , 0 , 1 )
1219- enc , _ = self .quant_conv (encoded )
1216+ # 1. Process the first frame
1217+ # Initialize cache for the first frame
1218+ init_cache_first = self .encoder .init_cache (b , h , w , x .dtype )
1219+ x_first_scan = jnp .expand_dims (x [:, 0 , ...], axis = 0 ) # Shape (1, B, H, W, C) for scan
1220+
1221+ _ , out_first_frames = jax .lax .scan (scan_fn_chunk , init_cache_first , x_first_scan )
1222+ # out_first_frames shape is (1, B, H', W', C')
1223+ all_outs .append (jnp .swapaxes (out_first_frames , 0 , 1 )) # Shape (B, 1, H', W', C')
1224+
1225+ # 2. Process the remaining frames in chunks of up to 4 (the last chunk may be shorter)
1226+ if t > 1 :
1227+ num_chunks = (t - 1 + 3 ) // 4 # Ceiling division
1228+ for i in range (num_chunks ):
1229+ start_idx = 1 + 4 * i
1230+ end_idx = min (start_idx + 4 , t )
1231+
1232+ if start_idx >= t :
1233+ break
1234+
1235+ chunk = x [:, start_idx :end_idx , ...]
1236+ # Prepare chunk for scan: shape (T_chunk, B, H, W, C)
1237+ x_scan = jnp .swapaxes (chunk , 0 , 1 )
1238+
1239+ # Reset the encoder cache for each chunk: every chunk starts from a fresh init_cache, so no state carries over from the previous chunk
1240+ init_cache_chunk = self .encoder .init_cache (b , h , w , x .dtype )
1241+
1242+ _ , encoded_frames_chunk = jax .lax .scan (scan_fn_chunk , init_cache_chunk , x_scan )
1243+ # encoded_frames_chunk shape is (T_chunk, B, H', W', C')
1244+
1245+ # Transpose back to (B, T_chunk, H', W', C')
1246+ out_chunk = jnp .swapaxes (encoded_frames_chunk , 0 , 1 )
1247+ all_outs .append (out_chunk )
1248+
1249+ # Concatenate results from all chunks along the time dimension
1250+ encoded = jnp .concatenate (all_outs , axis = 1 )
1251+
1252+ # Apply quant_conv with no incoming cache (cache_x=None); as in the previous code, the cache it returns is discarded.
1253+ enc , _ = self .quant_conv (encoded , cache_x = None )
12201254
12211255 mu , logvar = enc [:, :, :, :, : self .z_dim ], enc [:, :, :, :, self .z_dim :]
12221256 h = jnp .concatenate ([mu , logvar ], axis = - 1 )
0 commit comments