Skip to content

Commit 081feff

Browse files
committed
modify encode
1 parent e154cf2 commit 081feff

1 file changed

Lines changed: 15 additions & 46 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 15 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,65 +1247,34 @@ def __init__(
12471247
precision=precision,
12481248
)
12491249

1250-
# @nnx.jit
1250+
# REMOVE @nnx.jit for now to ensure this logic runs
12511251
def encode(
12521252
self, x: jax.Array, return_dict: bool = True
12531253
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
12541254
if x.shape[-1] != 3:
1255-
# reshape channel last for JAX
12561255
x = jnp.transpose(x, (0, 2, 3, 4, 1))
12571256
assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
12581257

12591258
b, t, h, w, c = x.shape
1260-
all_outs = []
1261-
1262-
def scan_fn_chunk(carry, input_slice):
1263-
# input_slice shape is (B, H, W, C)
1264-
input_slice = jnp.expand_dims(input_slice, 1) # Shape (B, 1, H, W, C) for encoder
1265-
out_slice, new_carry = self.encoder(input_slice, carry)
1266-
# out_slice shape is (B, 1, H', W', C')
1267-
out_slice = jnp.squeeze(out_slice, 1) # Shape (B, H', W', C') for scan output
1268-
return new_carry, out_slice
1269-
1270-
# 1. Process the first frame
1271-
# Initialize cache for the first frame
1272-
init_cache_first = self.encoder.init_cache(b, h, w, x.dtype)
1273-
x_first_scan = jnp.expand_dims(x[:, 0, ...], axis=0) # Shape (1, B, H, W, C) for scan
1259+
init_cache = self.encoder.init_cache(b, h, w, x.dtype)
12741260

1275-
_, out_first_frames = jax.lax.scan(scan_fn_chunk, init_cache_first, x_first_scan)
1276-
# out_first_frames shape is (1, B, H', W', C')
1277-
all_outs.append(jnp.swapaxes(out_first_frames, 0, 1)) # Shape (B, 1, H', W', C')
1261+
# Process first frame
1262+
out1, _ = self.encoder(x[:, :1, ...], init_cache)
12781263

1279-
# 2. Process subsequent Chunks of 4
12801264
if t > 1:
1281-
num_chunks = (t - 1 + 3) // 4 # Ceiling division
1282-
for i in range(num_chunks):
1283-
start_idx = 1 + 4 * i
1284-
end_idx = min(start_idx + 4, t)
1285-
1286-
if start_idx >= t:
1287-
break
1288-
1289-
chunk = x[:, start_idx:end_idx, ...]
1290-
# Prepare chunk for scan: shape (T_chunk, B, H, W, C)
1291-
x_scan = jnp.swapaxes(chunk, 0, 1)
1292-
1293-
# *** Cache Reset for EACH CHUNK ***
1294-
init_cache_chunk = self.encoder.init_cache(b, h, w, x.dtype)
1295-
1296-
_, encoded_frames_chunk = jax.lax.scan(scan_fn_chunk, init_cache_chunk, x_scan)
1297-
# encoded_frames_chunk shape is (T_chunk, B, H', W', C')
1298-
1299-
# Transpose back to (B, T_chunk, H', W', C')
1300-
out_chunk = jnp.swapaxes(encoded_frames_chunk, 0, 1)
1301-
all_outs.append(out_chunk)
1302-
1303-
# Concatenate results from all chunks along the time dimension
1304-
encoded = jnp.concatenate(all_outs, axis=1)
1265+
# Process remaining frames in one chunk
1266+
# We need to manage cache updates manually if not using scan
1267+
# This part is tricky because the new encoder returns cache,
1268+
# but the old logic didn't seem to carry cache between chunks.
1269+
1270+
# Let's SIMPLIFY to match the OLD logic's spirit: Reset cache for the chunk
1271+
init_cache_rest = self.encoder.init_cache(b, h, w, x.dtype)
1272+
out_rest, _ = self.encoder(x[:, 1:, ...], init_cache_rest)
1273+
encoded = jnp.concatenate([out1, out_rest], axis=1)
1274+
else:
1275+
encoded = out1
13051276

1306-
# Apply quant_conv - this layer also has a cache, but the old code didn't pipe it.
13071277
enc, _ = self.quant_conv(encoded, cache_x=None)
1308-
13091278
mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
13101279
h = jnp.concatenate([mu, logvar], axis=-1)
13111280

0 commit comments

Comments
 (0)