Skip to content

Commit 15fe583

Browse files
committed
encode decode modified
1 parent b0e4851 commit 15fe583

1 file changed

Lines changed: 60 additions & 38 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 60 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,49 +1258,79 @@ def __init__(
12581258
precision=precision,
12591259
)
12601260

@nnx.jit
def encode(
    self, x: jax.Array, return_dict: bool = True
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
    """Encode a video batch into a latent Gaussian posterior.

    The time axis is zero-padded up to a multiple of ``chunk_size`` and the
    encoder is applied to each 4-frame chunk under ``jax.lax.scan``. The
    encoder cache is re-initialised for every chunk, so chunks are encoded
    independently (no temporal state flows between them).

    Args:
      x: Input video, either channel-last ``(B, T, H, W, 3)`` or
        channel-first ``(B, 3, T, H, W)``; channel-first input is
        transposed to channel-last internally.
      return_dict: When True, wrap the posterior in a
        ``FlaxAutoencoderKLOutput``; otherwise return a 1-tuple.

    Returns:
      ``FlaxAutoencoderKLOutput`` holding the latent distribution, or
      ``(posterior,)`` when ``return_dict`` is False.
    """
    if x.shape[-1] != 3:
        # Accept channel-first input; move channels last for JAX convs.
        x = jnp.transpose(x, (0, 2, 3, 4, 1))
    assert x.shape[-1] == 3, "Input channels must be 3"

    b, t, h, w, c = x.shape
    chunk_size = 4  # frames per independently-encoded chunk

    # Zero-pad the time axis so it divides evenly into chunks.
    num_chunks = math.ceil(t / chunk_size)
    pad_t = num_chunks * chunk_size - t
    if pad_t > 0:
        pad_widths = [(0, 0)] * x.ndim
        pad_widths[1] = (0, pad_t)  # pad at the end of the time axis
        x = jnp.pad(x, pad_widths, mode='constant', constant_values=0.0)

    # (B, T, H, W, C) -> (num_chunks, B, chunk_size, H, W, C) for scan.
    chunks = jnp.swapaxes(x.reshape((b, num_chunks, chunk_size, h, w, c)), 0, 1)

    def scan_fn(carry, x_chunk):
        # x_chunk: (B, chunk_size, H, W, C). A fresh cache per chunk keeps
        # chunks independent of one another.
        b_c, _, h_c, w_c, _ = x_chunk.shape
        cache = self.encoder.init_cache(b_c, h_c, w_c, x_chunk.dtype)
        # Rematerialise (checkpoint) the encoder to bound peak memory.
        out_chunk, _ = nnx.checkpoint(self.encoder)(x_chunk, cache)
        return carry, out_chunk

    # Fix: the carry is never read by scan_fn, so use an empty pytree
    # instead of allocating a full (and unused) encoder cache for it.
    _, encoded_chunks = jax.lax.scan(scan_fn, None, chunks)
    # encoded_chunks: (num_chunks, B, T', H', W', 2*z_dim), where T' is the
    # per-chunk temporal output length (1 for a 4x temporal downsample).

    # Merge the chunk axis and the per-chunk time axis into one time axis,
    # back in batch-major layout: (B, num_chunks * T', H', W', 2*z_dim).
    stacked = jnp.swapaxes(encoded_chunks, 0, 1)
    b_o, nc_o, t_o, h_o, w_o, c_o = stacked.shape
    encoded = stacked.reshape((b_o, nc_o * t_o, h_o, w_o, c_o))

    # Fix: the original sliced `enc` into (mu, logvar) along the channel
    # axis and immediately re-concatenated them — an identity op. Pass
    # `enc` straight through; it already holds [mu, logvar]. This also
    # removes the rebinding of `h`, which shadowed the spatial dim above.
    enc, _ = self.quant_conv(encoded, cache_x=None)

    posterior = FlaxDiagonalGaussianDistribution(enc)
    if not return_dict:
        return (posterior,)
    return FlaxAutoencoderKLOutput(latent_dist=posterior)
13021333

1303-
13041334
@nnx.jit
13051335
def decode(
13061336
self, z: jax.Array, return_dict: bool = True
@@ -1315,22 +1345,14 @@ def decode(
13151345
init_cache = self.decoder.init_cache(b, h, w, x.dtype)
13161346

13171347
def scan_fn(carry, input_slice):
1318-
# Expand Time dimension for Conv3d
13191348
input_slice = jnp.expand_dims(input_slice, 1)
1320-
# OPTIMIZATION: Force bfloat16 accumulation within the scan
1321-
# to save memory on the massive output buffer
13221349
out_slice, new_carry = self.decoder(input_slice, carry)
1323-
# out_slice = out_slice.astype(jnp.bfloat16)
13241350
return new_carry, out_slice
13251351

1326-
final_cache, decoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
1352+
final_cache, decoded_frames = jax.lax.scan(scan_fn, initial_scan_carry, x_scan)
13271353

1328-
# decoded_frames shape: (T_lat, B, 4, H, W, C)
1329-
# We need to flatten T_lat and 4.
1330-
# Transpose to (B, T_lat, 4, H, W, C)
13311354
decoded = jnp.transpose(decoded_frames, (1, 0, 2, 3, 4, 5))
13321355

1333-
# Reshape to (B, T_lat*4, H, W, C)
13341356
b, t_lat, t_sub, h, w, c = decoded.shape
13351357
decoded = decoded.reshape(b, t_lat * t_sub, h, w, c)
13361358

0 commit comments

Comments
 (0)