Skip to content

Commit f4de267

Browse files
committed
Corrected the `encode` method
1 parent 1f712f8 commit f4de267

1 file changed

Lines changed: 44 additions & 10 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,21 +1202,55 @@ def encode(
12021202
x = jnp.transpose(x, (0, 2, 3, 4, 1))
12031203
assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
12041204

1205-
x_scan = jnp.swapaxes(x, 0, 1)
12061205
b, t, h, w, c = x.shape
1207-
init_cache = self.encoder.init_cache(b, h, w, x.dtype)
1206+
all_outs = []
12081207

1209-
def scan_fn(carry, input_slice):
1210-
# Expand Time dimension for Conv3d
1211-
input_slice = jnp.expand_dims(input_slice, 1)
1208+
def scan_fn_chunk(carry, input_slice):
1209+
# input_slice shape is (B, H, W, C)
1210+
input_slice = jnp.expand_dims(input_slice, 1) # Shape (B, 1, H, W, C) for encoder
12121211
out_slice, new_carry = self.encoder(input_slice, carry)
1213-
# Squeeze Time dimension for scan stacking
1214-
out_slice = jnp.squeeze(out_slice, 1)
1212+
# out_slice shape is (B, 1, H', W', C')
1213+
out_slice = jnp.squeeze(out_slice, 1) # Shape (B, H', W', C') for scan output
12151214
return new_carry, out_slice
12161215

1217-
final_cache, encoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
1218-
encoded = jnp.swapaxes(encoded_frames, 0, 1)
1219-
enc, _ = self.quant_conv(encoded)
1216+
# 1. Process the first frame
1217+
# Initialize cache for the first frame
1218+
init_cache_first = self.encoder.init_cache(b, h, w, x.dtype)
1219+
x_first_scan = jnp.expand_dims(x[:, 0, ...], axis=0) # Shape (1, B, H, W, C) for scan
1220+
1221+
_, out_first_frames = jax.lax.scan(scan_fn_chunk, init_cache_first, x_first_scan)
1222+
# out_first_frames shape is (1, B, H', W', C')
1223+
all_outs.append(jnp.swapaxes(out_first_frames, 0, 1)) # Shape (B, 1, H', W', C')
1224+
1225+
# 2. Process subsequent Chunks of 4
1226+
if t > 1:
1227+
num_chunks = (t - 1 + 3) // 4 # Ceiling division
1228+
for i in range(num_chunks):
1229+
start_idx = 1 + 4 * i
1230+
end_idx = min(start_idx + 4, t)
1231+
1232+
if start_idx >= t:
1233+
break
1234+
1235+
chunk = x[:, start_idx:end_idx, ...]
1236+
# Prepare chunk for scan: shape (T_chunk, B, H, W, C)
1237+
x_scan = jnp.swapaxes(chunk, 0, 1)
1238+
1239+
# *** Cache Reset for EACH CHUNK ***
1240+
init_cache_chunk = self.encoder.init_cache(b, h, w, x.dtype)
1241+
1242+
_, encoded_frames_chunk = jax.lax.scan(scan_fn_chunk, init_cache_chunk, x_scan)
1243+
# encoded_frames_chunk shape is (T_chunk, B, H', W', C')
1244+
1245+
# Transpose back to (B, T_chunk, H', W', C')
1246+
out_chunk = jnp.swapaxes(encoded_frames_chunk, 0, 1)
1247+
all_outs.append(out_chunk)
1248+
1249+
# Concatenate results from all chunks along the time dimension
1250+
encoded = jnp.concatenate(all_outs, axis=1)
1251+
1252+
# Apply quant_conv - this layer also has a cache, but the old code didn't pipe it.
1253+
enc, _ = self.quant_conv(encoded, cache_x=None)
12201254

12211255
mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
12221256
h = jnp.concatenate([mu, logvar], axis=-1)

0 commit comments

Comments
 (0)