@@ -1197,84 +1197,34 @@ def __init__(
     def encode(
         self, x: jax.Array, return_dict: bool = True
     ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
-        if x.shape[-1] != 3:
-            x = jnp.transpose(x, (0, 2, 3, 4, 1))
-
-        b, t, h, w, c = x.shape
-
-        # --- STEP 1: Process the first frame (with a padding hack) ---
-        x_first = x[:, :1, ...]  # (B, 1, ...)
-
-        # Pad this single frame to T=4 so it survives the temporal strides;
-        # only the first frame of the result is kept.
-        x_first_padded = jnp.concatenate([x_first] * 4, axis=1)  # (B, 4, ...)
-
-        # Initialize the cache
-        init_cache = self.encoder.init_cache(b, h, w, x.dtype)
-
-        # Run the encoder on the padded first frame. Only the first output
-        # frame is kept, but the cache update is kept as well: padding with
-        # copies of frame 0 fills the cache with frame 0's history, which is
-        # the correct causal state for a static image or for the start of a
-        # video, so the following chunks can continue from it.
-
-        enc_first_padded, cache_after_first = self.encoder(x_first_padded, init_cache)
-
-        # Take only the first frame of the output
-        enc_first = enc_first_padded[:, :1, ...]
-
-        # --- STEP 2: Process the remaining frames in chunks of 4 ---
-        x_rest = x[:, 1:, ...]
-        t_rest = t - 1
-
-        # Pad the remainder to be divisible by 4
-        pad_len = (4 - (t_rest % 4)) % 4
-        if pad_len > 0:
-            last = x_rest[:, -1:, ...]
-            padding = jnp.repeat(last, pad_len, axis=1)
-            x_rest_padded = jnp.concatenate([x_rest, padding], axis=1)
-        else:
-            x_rest_padded = x_rest
-
-        num_chunks = x_rest_padded.shape[1] // 4
-        x_chunks = x_rest_padded.reshape(b, num_chunks, 4, h, w, c)
-        x_chunks = jnp.transpose(x_chunks, (1, 0, 2, 3, 4, 5))
-
-        # Scan over chunks, threading the cache through as the carry
-        def scan_fn(carry, input_chunk):
-            out_chunk, new_carry = self.encoder(input_chunk, carry)
-            return new_carry, out_chunk
-
-        final_cache, enc_rest_chunks = jax.lax.scan(scan_fn, cache_after_first, x_chunks)
-
-        # Flatten the chunk dimension back into time
-        enc_rest_chunks = jnp.swapaxes(enc_rest_chunks, 0, 1)
-        b_out, n_chunks, t_chunk, h_out, w_out, c_out = enc_rest_chunks.shape
-        enc_rest = enc_rest_chunks.reshape(b_out, n_chunks * t_chunk, h_out, w_out, c_out)
-
-        # No output slicing is needed for the input padding: each chunk of 4
-        # input frames maps to exactly one latent frame, so chunking the
-        # padded input yields ceil(t_rest / 4) latent frames, which is
-        # already the count we need.
-        if pad_len > 0:
-            # Example: t_rest = 5 is padded to 8, giving 2 chunks and hence
-            # 2 latent frames. The real number of latents needed is
-            # ceil(5 / 4) = 2, so nothing has to be sliced off here: the
-            # ceiling behavior is exactly what we want.
-            pass
-
-        # Concatenate the first-frame latent with the rest
-        encoded = jnp.concatenate([enc_first, enc_rest], axis=1)
-
-        # Quantize and split into mean and log-variance
-        enc, _ = self.quant_conv(encoded)
-        mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
-        h_latents = jnp.concatenate([mu, logvar], axis=-1)
-
-        posterior = FlaxDiagonalGaussianDistribution(h_latents)
-        if not return_dict:
-            return (posterior,)
-        return FlaxAutoencoderKLOutput(latent_dist=posterior)
+        if x.shape[-1] != 3:
+            # Reshape to channels-last for JAX
+            x = jnp.transpose(x, (0, 2, 3, 4, 1))
+        assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
+
+        x_scan = jnp.swapaxes(x, 0, 1)
+        b, t, h, w, c = x.shape
+        init_cache = self.encoder.init_cache(b, h, w, x.dtype)
+
+        def scan_fn(carry, input_slice):
+            # Expand the time dimension for the 3D convolutions
+            input_slice = jnp.expand_dims(input_slice, 1)
+            out_slice, new_carry = self.encoder(input_slice, carry)
+            # Squeeze the time dimension so scan can stack the outputs
+            out_slice = jnp.squeeze(out_slice, 1)
+            return new_carry, out_slice
+
+        final_cache, encoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
+        encoded = jnp.swapaxes(encoded_frames, 0, 1)
+        enc, _ = self.quant_conv(encoded)
+
+        mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
+        h = jnp.concatenate([mu, logvar], axis=-1)
+
+        posterior = FlaxDiagonalGaussianDistribution(h)
+        if not return_dict:
+            return (posterior,)
+        return FlaxAutoencoderKLOutput(latent_dist=posterior)

     @nnx.jit
     def decode(
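
For intuition, the new `encode` uses the standard JAX pattern of threading a recurrent state (here, the encoder's causal-convolution cache) through `jax.lax.scan`, encoding one frame per step so peak memory stays constant in the clip length instead of growing with padded 4-frame chunks. Below is a minimal, self-contained sketch of that pattern; `toy_encoder` and its cache shape are invented for illustration and do not reflect the real `self.encoder` or `init_cache`.

import jax
import jax.numpy as jnp

def toy_encoder(frame, cache):
    # Stand-in for self.encoder: mixes the current frame with the cached
    # previous frame, then returns the current frame as the new cache
    # (a real causal conv would keep per-layer history instead).
    # frame: (B, 1, H, W, C); cache: (B, H, W, C)
    mixed = 0.5 * (frame[:, 0] + cache)
    return mixed[:, None], frame[:, 0]

b, t, h, w, c = 1, 8, 4, 4, 3
x = jnp.ones((b, t, h, w, c))
x_scan = jnp.swapaxes(x, 0, 1)             # scan over time: (T, B, H, W, C)
init_cache = jnp.zeros((b, h, w, c))

def scan_fn(carry, frame):
    frame = jnp.expand_dims(frame, 1)      # re-add the time dim for the conv
    out, new_carry = toy_encoder(frame, carry)
    return new_carry, jnp.squeeze(out, 1)  # drop time dim so scan can stack

final_cache, ys = jax.lax.scan(scan_fn, init_cache, x_scan)
encoded = jnp.swapaxes(ys, 0, 1)           # back to (B, T, H, W, C)
print(encoded.shape)                       # (1, 8, 4, 4, 3)

Because the carry has a fixed shape, this compiles to a single rolled loop; only one frame and one cache live in memory at a time, which is the memory behavior the rewritten `encode` is after.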