@@ -1197,61 +1197,76 @@ def __init__(
11971197 def encode (
11981198 self , x : jax .Array , return_dict : bool = True
11991199 ) -> Union [FlaxAutoencoderKLOutput , Tuple [FlaxDiagonalGaussianDistribution ]]:
1200- # 1. Standard Transpose Check (Matches Old)
12011200 if x .shape [- 1 ] != 3 :
12021201 x = jnp .transpose (x , (0 , 2 , 3 , 4 , 1 ))
1203- assert x .shape [- 1 ] == 3 , f"Expected input shape (N, D, H, W, 3), got { x .shape } "
1204-
1202+
12051203 b , t , h , w , c = x .shape
1206-
1207- # 2. Replicate "First Frame" Logic (Matches Old 'if i == 0')
1208- # The first frame is processed individually to prime the cache.
1209- x_first = x [:, :1 , ...] # Shape: (B, 1, H, W, C)
12101204
1211- # 3. Replicate "Chunking" Logic (Matches Old 'else: chunks of 4')
1212- # We take the remaining frames (Index 1 to End)
1213- x_rest = x [:, 1 :, ...] # Shape: (B, T-1, H, W, C)
1205+ # --- STEP 1: Process First Frame (With Padding Hack) ---
1206+ x_first = x [:, :1 , ...] # (B, 1, ...)
12141207
1215- # We assume the remaining frames are divisible by 4 (e.g. 80 frames)
1216- # Reshape to (Num_Chunks, B, 4, H, W, C) for the scan loop
1217- t_rest = t - 1
1218- assert t_rest % 4 == 0 , f"Remaining frames { t_rest } must be divisible by 4 (Total frames must be 1 + 4*k)"
1219- num_chunks = t_rest // 4
1208+ # We PAD this single frame to T=4 so it survives the strides.
1209+ # We will take only the first result.
1210+ x_first_padded = jnp .concatenate ([x_first ] * 4 , axis = 1 ) # (B, 4, ...)
12201211
1221- # Prepare for scan: Swap axis 0 and 1 so 'num_chunks' is the scan iterator
1222- x_chunks = x_rest .reshape (b , num_chunks , 4 , h , w , c )
1223- x_chunks = jnp .transpose (x_chunks , (1 , 0 , 2 , 3 , 4 , 5 ))
1224-
1225- # 4. Initialize Cache
1212+ # Initialize Cache
12261213 init_cache = self .encoder .init_cache (b , h , w , x .dtype )
1214+
1215+ # Run Encoder on padded first frame
1216+ # The cache returned by this call is kept, not discarded: it is passed to the
1217+ # scan over the remaining frames as the initial carry.
1218+ # Because the padded input is frame 0 repeated four times, the cache will be filled
1219+ # with frame 0's history, which is assumed correct for a static image or start of video.
1220+
1221+ enc_first_padded , cache_after_first = self .encoder (x_first_padded , init_cache )
1222+
1223+ # Take only the first frame of the output
1224+ enc_first = enc_first_padded [:, :1 , ...]
12271225
1228- # 5. Execute First Frame
1229- # This corresponds to the 'i=0' iteration in the old loop.
1230- enc_first , cache_after_first = self .encoder (x_first , init_cache )
1226+ # --- STEP 2: Process Rest of Frames (Chunks of 4) ---
1227+ x_rest = x [:, 1 :, ...]
1228+ t_rest = t - 1
1229+
1230+ # Pad remainder to be divisible by 4
1231+ pad_len = (4 - (t_rest % 4 )) % 4
1232+ if pad_len > 0 :
1233+ last = x_rest [:, - 1 :, ...]
1234+ padding = jnp .repeat (last , pad_len , axis = 1 )
1235+ x_rest_padded = jnp .concatenate ([x_rest , padding ], axis = 1 )
1236+ else :
1237+ x_rest_padded = x_rest
1238+
1239+ num_chunks = x_rest_padded .shape [1 ] // 4
1240+ x_chunks = x_rest_padded .reshape (b , num_chunks , 4 , h , w , c )
1241+ x_chunks = jnp .transpose (x_chunks , (1 , 0 , 2 , 3 , 4 , 5 ))
12311242
1232- # 6. Execute Scan on Chunks
1233- # This corresponds to the 'i > 0' iterations in the old loop.
1243+ # Scan Function
12341244 def scan_fn (carry , input_chunk ):
1235- # input_chunk is (B, 4, H, W, C).
1236- # The encoder naturally consumes 4 frames and outputs 1 latent frame (due to stride)
12371245 out_chunk , new_carry = self .encoder (input_chunk , carry )
12381246 return new_carry , out_chunk
12391247
12401248 final_cache , enc_rest_chunks = jax .lax .scan (scan_fn , cache_after_first , x_chunks )
12411249
1242- # 7. Flatten and Reassemble
1243- # enc_rest_chunks: (Num_Chunks, B, T_latent_chunk, ...)
1244- # We swap back to (B, Num_Chunks, ...) and flatten
1250+ # Flatten Rest
12451251 enc_rest_chunks = jnp .swapaxes (enc_rest_chunks , 0 , 1 )
1252+ b_out , n_chunks , t_chunk , h_out , w_out , c_out = enc_rest_chunks .shape
1253+ enc_rest = enc_rest_chunks .reshape (b_out , n_chunks * t_chunk , h_out , w_out , c_out )
12461254
1247- # Flatten the chunks into a continuous sequence
1248- b_out , n_chunks , t_chunk_out , h_out , w_out , c_out = enc_rest_chunks .shape
1249- enc_rest = enc_rest_chunks .reshape (b_out , n_chunks * t_chunk_out , h_out , w_out , c_out )
1250-
1251- # Concatenate: [First Frame Result] + [Rest of Frames Result]
1255+ # No slicing of the result is performed after padding.
1256+ # We padded the input by 'pad_len' frames, which is at most 3, so padding only
1257+ # completes the last partial chunk and never adds a whole extra chunk.
1258+ # Assuming each chunk of 4 frames yields 1 latent, the mapping is 1-to-1 chunk-to-latent.
1259+ if pad_len > 0 :
1260+ # We padded inputs. Does that mean we generated extra latents?
1261+ # If t_rest=5. Pad to 8. Chunks=2. Output=2 latents.
1262+ # Real latents needed: ceil(5/4) = 2.
1263+ # So actually, we don't need to slice! The ceiling behavior is what we want.
1264+ pass
1265+
1266+ # Concatenate
12521267 encoded = jnp .concatenate ([enc_first , enc_rest ], axis = 1 )
12531268
1254- # 8. Post-Processing (Matches Old Logic exactly)
1269+ # Quantize
12551270 enc , _ = self .quant_conv (encoded )
12561271 mu , logvar = enc [:, :, :, :, : self .z_dim ], enc [:, :, :, :, self .z_dim :]
12571272 h_latents = jnp .concatenate ([mu , logvar ], axis = - 1 )
0 commit comments