Skip to content

Commit 4744e19

Browse files
committed
encode method fixed
1 parent 5b4d511 commit 4744e19

1 file changed

Lines changed: 63 additions & 28 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 63 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,34 +1197,69 @@ def __init__(
11971197
def encode(
    self, x: jax.Array, return_dict: bool = True
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
  """Encode a video tensor into a latent Gaussian posterior.

  The encoder is causal over time: the first frame is processed alone to
  prime the temporal cache (the ``i == 0`` branch of the original
  frame-by-frame loop), then the remaining frames are consumed in chunks
  of 4 via ``jax.lax.scan``. The encoder's temporal stride reduces each
  4-frame chunk to one latent frame — assumed from the chunked reshape
  below; TODO confirm against the encoder definition.

  Args:
    x: Input video, either channels-last ``(N, D, H, W, 3)`` or
      channels-first ``(N, 3, D, H, W)`` (transposed automatically).
    return_dict: If ``False``, return a one-tuple ``(posterior,)``
      instead of a ``FlaxAutoencoderKLOutput``.

  Returns:
    ``FlaxAutoencoderKLOutput`` wrapping the posterior distribution, or
    the tuple ``(posterior,)`` when ``return_dict`` is ``False``.

  Raises:
    ValueError: If the input is not RGB, or the frame count is not of
      the form ``1 + 4*k``.
  """
  # Accept channels-first input by moving channels last for JAX convs.
  if x.shape[-1] != 3:
    x = jnp.transpose(x, (0, 2, 3, 4, 1))
  # Explicit raise instead of `assert`: asserts are stripped under -O.
  if x.shape[-1] != 3:
    raise ValueError(f"Expected input shape (N, D, H, W, 3), got {x.shape}")

  b, t, h, w, c = x.shape

  # First frame is processed individually to prime the temporal cache.
  x_first = x[:, :1, ...]  # (B, 1, H, W, C)

  # Remaining frames are consumed in chunks of 4.
  x_rest = x[:, 1:, ...]  # (B, T-1, H, W, C)
  t_rest = t - 1
  if t_rest % 4 != 0:
    raise ValueError(
        f"Remaining frames {t_rest} must be divisible by 4 (total frames must be 1 + 4*k)"
    )
  num_chunks = t_rest // 4

  # Put the chunk axis first so it becomes the scan iterator:
  # (num_chunks, B, 4, H, W, C).
  x_chunks = x_rest.reshape(b, num_chunks, 4, h, w, c)
  x_chunks = jnp.transpose(x_chunks, (1, 0, 2, 3, 4, 5))

  init_cache = self.encoder.init_cache(b, h, w, x.dtype)

  # Prime the cache with the first frame (old loop's i == 0 iteration).
  enc_first, cache_after_first = self.encoder(x_first, init_cache)

  def scan_fn(carry, chunk):
    # chunk: (B, 4, H, W, C). The encoder consumes the 4-frame chunk
    # and threads the temporal cache through `carry`.
    out_chunk, new_carry = self.encoder(chunk, carry)
    return new_carry, out_chunk

  # The cache after the final chunk is not needed downstream.
  _, enc_rest_chunks = jax.lax.scan(scan_fn, cache_after_first, x_chunks)

  # (num_chunks, B, t_chunk, h', w', c') -> flatten chunks back into one
  # contiguous time axis: (B, num_chunks * t_chunk, h', w', c').
  enc_rest_chunks = jnp.swapaxes(enc_rest_chunks, 0, 1)
  b_out, n_chunks, t_chunk, h_out, w_out, c_out = enc_rest_chunks.shape
  enc_rest = enc_rest_chunks.reshape(b_out, n_chunks * t_chunk, h_out, w_out, c_out)

  # Reassemble: [first-frame latent] + [latents for the remaining chunks].
  encoded = jnp.concatenate([enc_first, enc_rest], axis=1)

  # Post-processing: project, split into mean/log-variance, build posterior.
  enc, _ = self.quant_conv(encoded)
  mu, logvar = enc[..., : self.z_dim], enc[..., self.z_dim :]
  latents = jnp.concatenate([mu, logvar], axis=-1)

  posterior = FlaxDiagonalGaussianDistribution(latents)
  if not return_dict:
    return (posterior,)
  return FlaxAutoencoderKLOutput(latent_dist=posterior)
12281263

12291264
@nnx.jit
12301265
def decode(

0 commit comments

Comments
 (0)