@@ -1165,10 +1165,11 @@ def encode(
11651165 self , x : jax .Array , feat_cache : Optional [AutoencoderKLWanCache ] = None , return_dict : bool = True
11661166 ) -> Union [FlaxAutoencoderKLOutput , Tuple [FlaxDiagonalGaussianDistribution ]]:
11671167 """
1168- Encode video. Process the full video at once to handle temporal downsampling.
1168+ Encode video using jax.lax.scan for memory-efficient frame-by-frame processing.
1169+ Uses caching to maintain temporal context across frames.
11691170
11701171 Args:
1171- x: Input video tensor
1172+ x: Input video tensor (B, T, H, W, C)
11721173 feat_cache: Cache object (for API compatibility, not used internally)
11731174 return_dict: Whether to return FlaxAutoencoderKLOutput or tuple
11741175 """
@@ -1177,11 +1178,36 @@ def encode(
11771178
11781179 b , t , h , w , c = x .shape
11791180
1180- # Initialize cache once
1181+ # Initialize cache for encoder
11811182 init_cache = self .encoder .init_cache (b , h , w , x .dtype )
11821183
1183- # Process the full video through encoder (handles temporal downsampling internally)
1184- encoded , _ = self .encoder (x , init_cache )
1184+ # Prepare for scan: swap batch and time to iterate over time
1185+ x_scan = jnp .swapaxes (x , 0 , 1 ) # (T, B, H, W, C)
1186+
1187+ def scan_fn (carry , x_frame ):
1188+ """Process one frame at a time with cache."""
1189+ # x_frame is one slice of x_scan along its leading (time) axis: (B, H, W, C) given x is (B, T, H, W, C)
1190+ # Re-insert a singleton time axis at position 1 so the encoder receives a 5-D input
1191+ x_frame = jnp .expand_dims (x_frame , axis = 1 ) # (B, 1, H, W, C)
1192+
1193+ # carry is the encoder cache threaded through jax.lax.scan; the encoder call returns (output, updated cache)
1194+ out_frame , new_cache = self .encoder (x_frame , carry )
1195+
1196+ # out_frame is presumably (B, T_out, H', W', C'). NOTE(review): jax.lax.scan requires the same output shape and cache structure on every step, so T_out must be constant across frames -- confirm the encoder guarantees this
1197+ # The time axis is kept (not squeezed); the caller merges it with the scanned axis via reshape
1198+ return new_cache , out_frame
1199+
1200+ # Scan over time dimension
1201+ final_cache , encoded_frames = jax .lax .scan (scan_fn , init_cache , x_scan )
1202+ # encoded_frames shape: (T, B, T_out_per_frame, H', W', C')
1203+
1204+ # Concatenate along time dimension: (T, B, T_out, H', W', C') -> (B, T*T_out, H', W', C')
1205+ # First, swap to (B, T, T_out, H', W', C')
1206+ encoded_frames = jnp .swapaxes (encoded_frames , 0 , 1 )
1207+
1208+ # Then reshape to merge time dimensions
1209+ shape = encoded_frames .shape
1210+ encoded = encoded_frames .reshape (shape [0 ], shape [1 ] * shape [2 ], shape [3 ], shape [4 ], shape [5 ])
11851211
11861212 # Apply quantization convolution
11871213 enc , _ = self .quant_conv (encoded )
@@ -1198,10 +1224,11 @@ def decode(
11981224 self , z : jax .Array , feat_cache : Optional [AutoencoderKLWanCache ] = None , return_dict : bool = True
11991225 ) -> Union [FlaxDecoderOutput , jax .Array ]:
12001226 """
1201- Decode latents. Process the full latent at once to handle temporal upsampling.
1227+ Decode latents using jax.lax.scan for memory-efficient frame-by-frame processing.
1228+ Uses caching to maintain temporal context across frames.
12021229
12031230 Args:
1204- z: Latent tensor
1231+ z: Latent tensor (B, T, H, W, C)
12051232 feat_cache: Cache object (for API compatibility, not used internally)
12061233 return_dict: Whether to return FlaxDecoderOutput or tuple
12071234 """
@@ -1213,11 +1240,35 @@ def decode(
12131240
12141241 b , t , h , w , c = x .shape
12151242
1216- # Initialize cache once
1243+ # Initialize cache for decoder
12171244 init_cache = self .decoder .init_cache (b , h , w , x .dtype )
1218-
1219- # Process the full latent through decoder (handles temporal upsampling internally)
1220- decoded , _ = self .decoder (x , init_cache )
1245+
1246+ # Prepare for scan: swap batch and time to iterate over time
1247+ x_scan = jnp .swapaxes (x , 0 , 1 ) # (T, B, H, W, C)
1248+
1249+ def scan_fn (carry , x_frame ):
1250+ """Process one latent frame at a time with cache."""
1251+ # x_frame is one slice of x_scan along its leading (time) axis: (B, H, W, C) given x is unpacked as (b, t, h, w, c)
1252+ # Re-insert a singleton time axis at position 1 so the decoder receives a 5-D input
1253+ x_frame = jnp .expand_dims (x_frame , axis = 1 ) # (B, 1, H, W, C)
1254+
1255+ # carry is the decoder cache threaded through jax.lax.scan; the decoder call returns (output, updated cache)
1256+ out_frame , new_cache = self .decoder (x_frame , carry )
1257+
1258+ # out_frame is presumably (B, T_out, H', W', C'). NOTE(review): jax.lax.scan requires the same output shape and cache structure on every step, so T_out from temporal upsampling must be constant across frames -- confirm
1259+ return new_cache , out_frame
1260+
1261+ # Scan over time dimension
1262+ final_cache , decoded_frames = jax .lax .scan (scan_fn , init_cache , x_scan )
1263+ # decoded_frames shape: (T, B, T_out_per_frame, H', W', C')
1264+
1265+ # Concatenate along time dimension: (T, B, T_out, H', W', C') -> (B, T*T_out, H', W', C')
1266+ # First, swap to (B, T, T_out, H', W', C')
1267+ decoded_frames = jnp .swapaxes (decoded_frames , 0 , 1 )
1268+
1269+ # Then reshape to merge time dimensions
1270+ shape = decoded_frames .shape
1271+ decoded = decoded_frames .reshape (shape [0 ], shape [1 ] * shape [2 ], shape [3 ], shape [4 ], shape [5 ])
12211272
12221273 out = jnp .clip (decoded , min = - 1.0 , max = 1.0 )
12231274
0 commit comments