Skip to content

Commit 91d9a2a

Browse files
committed
full refactor
1 parent 4c28dc2 commit 91d9a2a

1 file changed

Lines changed: 8 additions & 66 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 8 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,8 +1165,7 @@ def encode(
11651165
self, x: jax.Array, feat_cache: Optional[AutoencoderKLWanCache] = None, return_dict: bool = True
11661166
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
11671167
"""
1168-
Encode video using jax.lax.scan for temporal iteration.
1169-
This enables proper JIT compilation while managing memory efficiently.
1168+
Encode video. Process the full video at once to handle temporal downsampling.
11701169
11711170
Args:
11721171
x: Input video tensor
@@ -1176,44 +1175,13 @@ def encode(
11761175
if x.shape[-1] != 3:
11771176
x = jnp.transpose(x, (0, 2, 3, 4, 1))
11781177

1179-
# Calculate temporal downsampling factor
1180-
temporal_downsample_factor = 1
1181-
for ds in self.temperal_downsample:
1182-
if ds:
1183-
temporal_downsample_factor *= 2
1184-
11851178
b, t, h, w, c = x.shape
11861179

1187-
# Process frames in chunks that match temporal downsampling
1188-
# This prevents frames from being downsampled to 0
1189-
chunk_size = temporal_downsample_factor
1190-
1191-
# Pad time dimension if needed to make it divisible by chunk_size
1192-
if t % chunk_size != 0:
1193-
pad_frames = chunk_size - (t % chunk_size)
1194-
x = jnp.pad(x, ((0, 0), (0, pad_frames), (0, 0), (0, 0), (0, 0)), mode='edge')
1195-
t = x.shape[1]
1196-
1197-
# Reshape to process chunks: (B, T, H, W, C) -> (T//chunk_size, B, chunk_size, H, W, C)
1198-
x_chunks = x.reshape(b, t // chunk_size, chunk_size, h, w, c)
1199-
x_scan = jnp.swapaxes(x_chunks, 0, 1) # -> (T//chunk_size, B, chunk_size, H, W, C)
1200-
1180+
# Initialize cache once
12011181
init_cache = self.encoder.init_cache(b, h, w, x.dtype)
1202-
1203-
def scan_fn(carry, input_chunk):
1204-
"""Scan function processes one chunk of frames at a time."""
1205-
# input_chunk shape: (B, chunk_size, H, W, C)
1206-
out_chunk, new_carry = self.encoder(input_chunk, carry)
1207-
return new_carry, out_chunk
1208-
1209-
# Use jax.lax.scan for JIT-compilable temporal iteration
1210-
final_cache, encoded_chunks = jax.lax.scan(scan_fn, init_cache, x_scan)
1211-
# encoded_chunks shape: (T//chunk_size, B, T_out_per_chunk, H', W', C')
12121182

1213-
# Reshape back: (T//chunk_size, B, T_out, H', W', C') -> (B, T_total, H', W', C')
1214-
n_chunks, batch, t_per_chunk, h_out, w_out, c_out = encoded_chunks.shape
1215-
encoded = jnp.transpose(encoded_chunks, (1, 0, 2, 3, 4, 5)) # (B, n_chunks, T_out, H', W', C')
1216-
encoded = encoded.reshape(batch, n_chunks * t_per_chunk, h_out, w_out, c_out)
1183+
# Process the full video through encoder (handles temporal downsampling internally)
1184+
encoded, _ = self.encoder(x, init_cache)
12171185

12181186
# Apply quantization convolution
12191187
enc, _ = self.quant_conv(encoded)
@@ -1230,8 +1198,7 @@ def decode(
12301198
self, z: jax.Array, feat_cache: Optional[AutoencoderKLWanCache] = None, return_dict: bool = True
12311199
) -> Union[FlaxDecoderOutput, jax.Array]:
12321200
"""
1233-
Decode latents using jax.lax.scan for temporal iteration.
1234-
This enables proper JIT compilation while managing memory efficiently.
1201+
Decode latents. Process the full latent at once to handle temporal upsampling.
12351202
12361203
Args:
12371204
z: Latent tensor
@@ -1244,38 +1211,13 @@ def decode(
12441211
# Apply post-quantization convolution
12451212
x, _ = self.post_quant_conv(z)
12461213

1247-
# Calculate temporal upsampling factor
1248-
temporal_upsample_factor = 1
1249-
for us in self.temporal_upsample:
1250-
if us:
1251-
temporal_upsample_factor *= 2
1252-
12531214
b, t, h, w, c = x.shape
12541215

1255-
# For decoder, we still process one frame at a time but output will be upsampled
1256-
x_scan = jnp.swapaxes(x, 0, 1) # (B, T, H, W, C) -> (T, B, H, W, C)
1257-
1216+
# Initialize cache once
12581217
init_cache = self.decoder.init_cache(b, h, w, x.dtype)
12591218

1260-
def scan_fn(carry, input_slice):
1261-
"""Scan function processes one latent frame at a time."""
1262-
# Expand time dimension for Conv3d compatibility
1263-
input_slice = jnp.expand_dims(input_slice, 1) # (B, H, W, C) -> (B, 1, H, W, C)
1264-
# Use bfloat16 accumulation to save memory
1265-
out_slice, new_carry = self.decoder(input_slice, carry)
1266-
out_slice = out_slice.astype(jnp.bfloat16)
1267-
return new_carry, out_slice
1268-
1269-
# Use jax.lax.scan for JIT-compilable temporal iteration
1270-
final_cache, decoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
1271-
1272-
# decoded_frames shape: (T_lat, B, T_upsample, H, W, C)
1273-
# Transpose to (B, T_lat, T_upsample, H, W, C)
1274-
decoded = jnp.transpose(decoded_frames, (1, 0, 2, 3, 4, 5))
1275-
1276-
# Reshape to (B, T_lat * T_upsample, H, W, C)
1277-
b, t_lat, t_sub, h, w, c = decoded.shape
1278-
decoded = decoded.reshape(b, t_lat * t_sub, h, w, c)
1219+
# Process the full latent through decoder (handles temporal upsampling internally)
1220+
decoded, _ = self.decoder(x, init_cache)
12791221

12801222
out = jnp.clip(decoded, min=-1.0, max=1.0)
12811223

0 commit comments

Comments (0)