@@ -1165,8 +1165,7 @@ def encode(
11651165 self , x : jax .Array , feat_cache : Optional [AutoencoderKLWanCache ] = None , return_dict : bool = True
11661166 ) -> Union [FlaxAutoencoderKLOutput , Tuple [FlaxDiagonalGaussianDistribution ]]:
11671167 """
1168- Encode video using jax.lax.scan for temporal iteration.
1169- This enables proper JIT compilation while managing memory efficiently.
1168+ Encode video. Process the full video at once to handle temporal downsampling.
11701169
11711170 Args:
11721171 x: Input video tensor
@@ -1176,44 +1175,13 @@ def encode(
11761175 if x .shape [- 1 ] != 3 :
11771176 x = jnp .transpose (x , (0 , 2 , 3 , 4 , 1 ))
11781177
1179- # Calculate temporal downsampling factor
1180- temporal_downsample_factor = 1
1181- for ds in self .temperal_downsample :
1182- if ds :
1183- temporal_downsample_factor *= 2
1184-
11851178 b , t , h , w , c = x .shape
11861179
1187- # Process frames in chunks that match temporal downsampling
1188- # This prevents frames from being downsampled to 0
1189- chunk_size = temporal_downsample_factor
1190-
1191- # Pad time dimension if needed to make it divisible by chunk_size
1192- if t % chunk_size != 0 :
1193- pad_frames = chunk_size - (t % chunk_size )
1194- x = jnp .pad (x , ((0 , 0 ), (0 , pad_frames ), (0 , 0 ), (0 , 0 ), (0 , 0 )), mode = 'edge' )
1195- t = x .shape [1 ]
1196-
1197- # Reshape to process chunks: (B, T, H, W, C) -> (T//chunk_size, B, chunk_size, H, W, C)
1198- x_chunks = x .reshape (b , t // chunk_size , chunk_size , h , w , c )
1199- x_scan = jnp .swapaxes (x_chunks , 0 , 1 ) # -> (T//chunk_size, B, chunk_size, H, W, C)
1200-
1180+ # Initialize cache once
12011181 init_cache = self .encoder .init_cache (b , h , w , x .dtype )
1202-
1203- def scan_fn (carry , input_chunk ):
1204- """Scan function processes one chunk of frames at a time."""
1205- # input_chunk shape: (B, chunk_size, H, W, C)
1206- out_chunk , new_carry = self .encoder (input_chunk , carry )
1207- return new_carry , out_chunk
1208-
1209- # Use jax.lax.scan for JIT-compilable temporal iteration
1210- final_cache , encoded_chunks = jax .lax .scan (scan_fn , init_cache , x_scan )
1211- # encoded_chunks shape: (T//chunk_size, B, T_out_per_chunk, H', W', C')
12121182
1213- # Reshape back: (T//chunk_size, B, T_out, H', W', C') -> (B, T_total, H', W', C')
1214- n_chunks , batch , t_per_chunk , h_out , w_out , c_out = encoded_chunks .shape
1215- encoded = jnp .transpose (encoded_chunks , (1 , 0 , 2 , 3 , 4 , 5 )) # (B, n_chunks, T_out, H', W', C')
1216- encoded = encoded .reshape (batch , n_chunks * t_per_chunk , h_out , w_out , c_out )
1183+ # Process the full video through encoder (handles temporal downsampling internally)
1184+ encoded , _ = self .encoder (x , init_cache )
12171185
12181186 # Apply quantization convolution
12191187 enc , _ = self .quant_conv (encoded )
@@ -1230,8 +1198,7 @@ def decode(
12301198 self , z : jax .Array , feat_cache : Optional [AutoencoderKLWanCache ] = None , return_dict : bool = True
12311199 ) -> Union [FlaxDecoderOutput , jax .Array ]:
12321200 """
1233- Decode latents using jax.lax.scan for temporal iteration.
1234- This enables proper JIT compilation while managing memory efficiently.
1201+ Decode latents. Process the full latent at once to handle temporal upsampling.
12351202
12361203 Args:
12371204 z: Latent tensor
@@ -1244,38 +1211,13 @@ def decode(
12441211 # Apply post-quantization convolution
12451212 x , _ = self .post_quant_conv (z )
12461213
1247- # Calculate temporal upsampling factor
1248- temporal_upsample_factor = 1
1249- for us in self .temporal_upsample :
1250- if us :
1251- temporal_upsample_factor *= 2
1252-
12531214 b , t , h , w , c = x .shape
12541215
1255- # For decoder, we still process one frame at a time but output will be upsampled
1256- x_scan = jnp .swapaxes (x , 0 , 1 ) # (B, T, H, W, C) -> (T, B, H, W, C)
1257-
1216+ # Initialize cache once
12581217 init_cache = self .decoder .init_cache (b , h , w , x .dtype )
12591218
1260- def scan_fn (carry , input_slice ):
1261- """Scan function processes one latent frame at a time."""
1262- # Expand time dimension for Conv3d compatibility
1263- input_slice = jnp .expand_dims (input_slice , 1 ) # (B, H, W, C) -> (B, 1, H, W, C)
1264- # Use bfloat16 accumulation to save memory
1265- out_slice , new_carry = self .decoder (input_slice , carry )
1266- out_slice = out_slice .astype (jnp .bfloat16 )
1267- return new_carry , out_slice
1268-
1269- # Use jax.lax.scan for JIT-compilable temporal iteration
1270- final_cache , decoded_frames = jax .lax .scan (scan_fn , init_cache , x_scan )
1271-
1272- # decoded_frames shape: (T_lat, B, T_upsample, H, W, C)
1273- # Transpose to (B, T_lat, T_upsample, H, W, C)
1274- decoded = jnp .transpose (decoded_frames , (1 , 0 , 2 , 3 , 4 , 5 ))
1275-
1276- # Reshape to (B, T_lat * T_upsample, H, W, C)
1277- b , t_lat , t_sub , h , w , c = decoded .shape
1278- decoded = decoded .reshape (b , t_lat * t_sub , h , w , c )
1219+ # Process the full latent through decoder (handles temporal upsampling internally)
1220+ decoded , _ = self .decoder (x , init_cache )
12791221
12801222 out = jnp .clip (decoded , min = - 1.0 , max = 1.0 )
12811223
0 commit comments