Skip to content

Commit 8430443

Browse files
committed with message: "full refactor"
1 parent 91d9a2a commit 8430443

1 file changed

Lines changed: 62 additions & 11 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,10 +1165,11 @@ def encode(
11651165
self, x: jax.Array, feat_cache: Optional[AutoencoderKLWanCache] = None, return_dict: bool = True
11661166
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
11671167
"""
1168-
Encode video. Process the full video at once to handle temporal downsampling.
1168+
Encode video using jax.lax.scan for memory-efficient frame-by-frame processing.
1169+
Uses caching to maintain temporal context across frames.
11691170
11701171
Args:
1171-
x: Input video tensor
1172+
x: Input video tensor (B, T, H, W, C)
11721173
feat_cache: Cache object (for API compatibility, not used internally)
11731174
return_dict: Whether to return FlaxAutoencoderKLOutput or tuple
11741175
"""
@@ -1177,11 +1178,36 @@ def encode(
11771178

11781179
b, t, h, w, c = x.shape
11791180

1180-
# Initialize cache once
1181+
# Initialize cache for encoder
11811182
init_cache = self.encoder.init_cache(b, h, w, x.dtype)
11821183

1183-
# Process the full video through encoder (handles temporal downsampling internally)
1184-
encoded, _ = self.encoder(x, init_cache)
1184+
# Prepare for scan: swap batch and time to iterate over time
1185+
x_scan = jnp.swapaxes(x, 0, 1) # (T, B, H, W, C)
1186+
1187+
def scan_fn(carry, x_frame):
1188+
"""Process one frame at a time with cache."""
1189+
# x_frame shape: (B, H, W, C)
1190+
# Add time dimension for processing
1191+
x_frame = jnp.expand_dims(x_frame, axis=1) # (B, 1, H, W, C)
1192+
1193+
# Process through encoder with cache
1194+
out_frame, new_cache = self.encoder(x_frame, carry)
1195+
1196+
# out_frame shape: (B, T_out, H', W', C') where T_out depends on temporal downsampling
1197+
# For stability, we keep the time dimension
1198+
return new_cache, out_frame
1199+
1200+
# Scan over time dimension
1201+
final_cache, encoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
1202+
# encoded_frames shape: (T, B, T_out_per_frame, H', W', C')
1203+
1204+
# Concatenate along time dimension: (T, B, T_out, H', W', C') -> (B, T*T_out, H', W', C')
1205+
# First, swap to (B, T, T_out, H', W', C')
1206+
encoded_frames = jnp.swapaxes(encoded_frames, 0, 1)
1207+
1208+
# Then reshape to merge time dimensions
1209+
shape = encoded_frames.shape
1210+
encoded = encoded_frames.reshape(shape[0], shape[1] * shape[2], shape[3], shape[4], shape[5])
11851211

11861212
# Apply quantization convolution
11871213
enc, _ = self.quant_conv(encoded)
@@ -1198,10 +1224,11 @@ def decode(
11981224
self, z: jax.Array, feat_cache: Optional[AutoencoderKLWanCache] = None, return_dict: bool = True
11991225
) -> Union[FlaxDecoderOutput, jax.Array]:
12001226
"""
1201-
Decode latents. Process the full latent at once to handle temporal upsampling.
1227+
Decode latents using jax.lax.scan for memory-efficient frame-by-frame processing.
1228+
Uses caching to maintain temporal context across frames.
12021229
12031230
Args:
1204-
z: Latent tensor
1231+
z: Latent tensor (B, T, H, W, C)
12051232
feat_cache: Cache object (for API compatibility, not used internally)
12061233
return_dict: Whether to return FlaxDecoderOutput or tuple
12071234
"""
@@ -1213,11 +1240,35 @@ def decode(
12131240

12141241
b, t, h, w, c = x.shape
12151242

1216-
# Initialize cache once
1243+
# Initialize cache for decoder
12171244
init_cache = self.decoder.init_cache(b, h, w, x.dtype)
1218-
1219-
# Process the full latent through decoder (handles temporal upsampling internally)
1220-
decoded, _ = self.decoder(x, init_cache)
1245+
1246+
# Prepare for scan: swap batch and time to iterate over time
1247+
x_scan = jnp.swapaxes(x, 0, 1) # (T, B, H, W, C)
1248+
1249+
def scan_fn(carry, x_frame):
1250+
"""Process one latent frame at a time with cache."""
1251+
# x_frame shape: (B, H, W, C)
1252+
# Add time dimension for processing
1253+
x_frame = jnp.expand_dims(x_frame, axis=1) # (B, 1, H, W, C)
1254+
1255+
# Process through decoder with cache (will upsample temporally)
1256+
out_frame, new_cache = self.decoder(x_frame, carry)
1257+
1258+
# out_frame shape: (B, T_out, H', W', C') where T_out depends on temporal upsampling
1259+
return new_cache, out_frame
1260+
1261+
# Scan over time dimension
1262+
final_cache, decoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
1263+
# decoded_frames shape: (T, B, T_out_per_frame, H', W', C')
1264+
1265+
# Concatenate along time dimension: (T, B, T_out, H', W', C') -> (B, T*T_out, H', W', C')
1266+
# First, swap to (B, T, T_out, H', W', C')
1267+
decoded_frames = jnp.swapaxes(decoded_frames, 0, 1)
1268+
1269+
# Then reshape to merge time dimensions
1270+
shape = decoded_frames.shape
1271+
decoded = decoded_frames.reshape(shape[0], shape[1] * shape[2], shape[3], shape[4], shape[5])
12211272

12221273
out = jnp.clip(decoded, min=-1.0, max=1.0)
12231274

0 commit comments

Comments (0)