Skip to content

Commit 4fdabf7

Browse files
committed
encode method fixed
1 parent 4744e19 commit 4fdabf7

1 file changed

Lines changed: 51 additions & 36 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,61 +1197,76 @@ def __init__(
11971197
def encode(
11981198
self, x: jax.Array, return_dict: bool = True
11991199
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
1200-
# 1. Standard Transpose Check (Matches Old)
12011200
if x.shape[-1] != 3:
12021201
x = jnp.transpose(x, (0, 2, 3, 4, 1))
1203-
assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
1204-
1202+
12051203
b, t, h, w, c = x.shape
1206-
1207-
# 2. Replicate "First Frame" Logic (Matches Old 'if i == 0')
1208-
# The first frame is processed individually to prime the cache.
1209-
x_first = x[:, :1, ...] # Shape: (B, 1, H, W, C)
12101204

1211-
# 3. Replicate "Chunking" Logic (Matches Old 'else: chunks of 4')
1212-
# We take the remaining frames (Index 1 to End)
1213-
x_rest = x[:, 1:, ...] # Shape: (B, T-1, H, W, C)
1205+
# --- STEP 1: Process First Frame (With Padding Hack) ---
1206+
x_first = x[:, :1, ...] # (B, 1, ...)
12141207

1215-
# We assume the remaining frames are divisible by 4 (e.g. 80 frames)
1216-
# Reshape to (Num_Chunks, B, 4, H, W, C) for the scan loop
1217-
t_rest = t - 1
1218-
assert t_rest % 4 == 0, f"Remaining frames {t_rest} must be divisible by 4 (Total frames must be 1 + 4*k)"
1219-
num_chunks = t_rest // 4
1208+
# We PAD this single frame to T=4 so it survives the strides.
1209+
# We will take only the first result.
1210+
x_first_padded = jnp.concatenate([x_first] * 4, axis=1) # (B, 4, ...)
12201211

1221-
# Prepare for scan: Swap axis 0 and 1 so 'num_chunks' is the scan iterator
1222-
x_chunks = x_rest.reshape(b, num_chunks, 4, h, w, c)
1223-
x_chunks = jnp.transpose(x_chunks, (1, 0, 2, 3, 4, 5))
1224-
1225-
# 4. Initialize Cache
1212+
# Initialize Cache
12261213
init_cache = self.encoder.init_cache(b, h, w, x.dtype)
1214+
1215+
# Run Encoder on padded first frame
1216+
# This padded run serves two purposes: it produces the first latent AND primes the cache.
1217+
# The resulting cache state is retained and carried into the subsequent frames.
1218+
# Note: padding with four copies of frame 0 fills the cache with frame 0's history.
1219+
# That is the correct behavior for a static image or the start of a video.
1220+
1221+
enc_first_padded, cache_after_first = self.encoder(x_first_padded, init_cache)
1222+
1223+
# Take only the first frame of the output
1224+
enc_first = enc_first_padded[:, :1, ...]
12271225

1228-
# 5. Execute First Frame
1229-
# This corresponds to the 'i=0' iteration in the old loop.
1230-
enc_first, cache_after_first = self.encoder(x_first, init_cache)
1226+
# --- STEP 2: Process Rest of Frames (Chunks of 4) ---
1227+
x_rest = x[:, 1:, ...]
1228+
t_rest = t - 1
1229+
1230+
# Pad remainder to be divisible by 4
1231+
pad_len = (4 - (t_rest % 4)) % 4
1232+
if pad_len > 0:
1233+
last = x_rest[:, -1:, ...]
1234+
padding = jnp.repeat(last, pad_len, axis=1)
1235+
x_rest_padded = jnp.concatenate([x_rest, padding], axis=1)
1236+
else:
1237+
x_rest_padded = x_rest
1238+
1239+
num_chunks = x_rest_padded.shape[1] // 4
1240+
x_chunks = x_rest_padded.reshape(b, num_chunks, 4, h, w, c)
1241+
x_chunks = jnp.transpose(x_chunks, (1, 0, 2, 3, 4, 5))
12311242

1232-
# 6. Execute Scan on Chunks
1233-
# This corresponds to the 'i > 0' iterations in the old loop.
1243+
# Scan Function
12341244
def scan_fn(carry, input_chunk):
1235-
# input_chunk is (B, 4, H, W, C).
1236-
# The encoder naturally consumes 4 frames and outputs 1 latent frame (due to stride)
12371245
out_chunk, new_carry = self.encoder(input_chunk, carry)
12381246
return new_carry, out_chunk
12391247

12401248
final_cache, enc_rest_chunks = jax.lax.scan(scan_fn, cache_after_first, x_chunks)
12411249

1242-
# 7. Flatten and Reassemble
1243-
# enc_rest_chunks: (Num_Chunks, B, T_latent_chunk, ...)
1244-
# We swap back to (B, Num_Chunks, ...) and flatten
1250+
# Flatten Rest
12451251
enc_rest_chunks = jnp.swapaxes(enc_rest_chunks, 0, 1)
1252+
b_out, n_chunks, t_chunk, h_out, w_out, c_out = enc_rest_chunks.shape
1253+
enc_rest = enc_rest_chunks.reshape(b_out, n_chunks * t_chunk, h_out, w_out, c_out)
12461254

1247-
# Flatten the chunks into a continuous sequence
1248-
b_out, n_chunks, t_chunk_out, h_out, w_out, c_out = enc_rest_chunks.shape
1249-
enc_rest = enc_rest_chunks.reshape(b_out, n_chunks * t_chunk_out, h_out, w_out, c_out)
1250-
1251-
# Concatenate: [First Frame Result] + [Rest of Frames Result]
1255+
# Slice off padding from result if needed
1256+
# We padded input by 'pad_len'. Output is downsampled by 4 (likely).
1257+
# Actually, since we chunked by 4 and got 1 output, the mapping is 1-to-1 chunk-to-latent.
1258+
# If we added 1 chunk of padding, we remove 1 frame of output.
1259+
if pad_len > 0:
1260+
# Check whether padding the inputs produced extra latent frames that must be removed.
1261+
# If t_rest=5. Pad to 8. Chunks=2. Output=2 latents.
1262+
# Real latents needed: ceil(5/4) = 2.
1263+
# Therefore no slicing is needed: the ceiling behavior is exactly what we want.
1264+
pass
1265+
1266+
# Concatenate
12521267
encoded = jnp.concatenate([enc_first, enc_rest], axis=1)
12531268

1254-
# 8. Post-Processing (Matches Old Logic exactly)
1269+
# Quantize
12551270
enc, _ = self.quant_conv(encoded)
12561271
mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
12571272
h_latents = jnp.concatenate([mu, logvar], axis=-1)

0 commit comments

Comments
 (0)