Skip to content

Commit b1f1b6d

Browse files
committed
updated encode decode methods
1 parent b7b6730 commit b1f1b6d

1 file changed

Lines changed: 78 additions & 66 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 78 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,51 +1263,52 @@ def _encode_jit(self, x: jax.Array) -> jax.Array:
12631263
"""Contains the core JAX computations for encoding, suitable for JIT."""
12641264
if x.shape[-1] != 3:
12651265
x = jnp.transpose(x, (0, 2, 3, 4, 1))
1266-
# assert x.shape[-1] == 3, "Input channels must be 3" # Assertions might not be ideal in JIT
12671266

12681267
b, t, h, w, c = x.shape
1269-
chunk_size = 4 # Process in chunks of 4 frames
1270-
1271-
num_chunks = math.ceil(t / chunk_size)
1272-
padded_t = num_chunks * chunk_size
1273-
padding_t = padded_t - t
1274-
1275-
if padding_t > 0:
1276-
# Pad the time dimension to be a multiple of chunk_size
1277-
paddings = [(0, 0)] * x.ndim
1278-
paddings[1] = (0, padding_t) # Pad at the end of the time dimension
1279-
x_padded = jnp.pad(x, paddings, mode='constant', constant_values=0.0)
1280-
else:
1281-
x_padded = x
1282-
1283-
# Reshape for scan: (B, Num_Chunks, Chunk_T, H, W, C)
1284-
x_reshaped = x_padded.reshape((b, num_chunks, chunk_size, h, w, c))
1285-
1286-
# Swap axes for scan: (Num_Chunks, B, Chunk_T, H, W, C)
1287-
x_scannable = jnp.swapaxes(x_reshaped, 0, 1)
1268+
all_outs = []
12881269

1270+
# Process the first frame (Time=1)
1271+
x_first = x[:, :1, ...]
1272+
init_cache_first = self.encoder.init_cache(b, h, w, x_first.dtype)
12891273
encoder_checkpointed = jax.checkpoint(self.encoder)
1274+
out1, state_carry = encoder_checkpointed(x_first, init_cache_first)
1275+
all_outs.append(out1)
1276+
1277+
# Process the remaining frames using scan over chunks of 4
1278+
if t > 1:
1279+
x_rest = x[:, 1:, ...]
1280+
t_rest = x_rest.shape[1]
1281+
chunk_size = 4
1282+
1283+
num_chunks = math.ceil(t_rest / chunk_size)
1284+
padded_t_rest = num_chunks * chunk_size
1285+
padding_t = padded_t_rest - t_rest
1286+
1287+
if padding_t > 0:
1288+
paddings = [(0, 0)] * x_rest.ndim
1289+
paddings[1] = (0, padding_t) # Pad at the end
1290+
x_rest_padded = jnp.pad(x_rest, paddings, mode='constant', constant_values=0.0)
1291+
else:
1292+
x_rest_padded = x_rest
12901293

1291-
def scan_fn(dummy_carry, x_chunk):
1292-
# x_chunk shape: (B, chunk_size, H, W, C)
1293-
b_c, _, h_c, w_c, _ = x_chunk.shape
1294-
init_cache = self.encoder.init_cache(b_c, h_c, w_c, x_chunk.dtype)
1295-
out_chunk, _ = encoder_checkpointed(x_chunk, init_cache)
1296-
return dummy_carry, out_chunk
1294+
x_reshaped = x_rest_padded.reshape((b, num_chunks, chunk_size, h, w, c))
1295+
x_scannable = jnp.swapaxes(x_reshaped, 0, 1)
12971296

1298-
initial_scan_carry = {}
1299-
_, encoded_chunks = jax.lax.scan(scan_fn, initial_scan_carry, x_scannable)
1297+
def scan_fn(carry_state, x_chunk):
1298+
out_chunk, new_state = encoder_checkpointed(x_chunk, carry_state)
1299+
return new_state, out_chunk
13001300

1301-
encoded_combined = jnp.swapaxes(encoded_chunks, 0, 1)
1301+
_, encoded_chunks = jax.lax.scan(scan_fn, state_carry, x_scannable)
1302+
encoded_rest = jnp.swapaxes(encoded_chunks, 0, 1)
1303+
b_out, nc_out, t_out_chunk, h_out, w_out, c_out = encoded_rest.shape
1304+
encoded_rest = encoded_rest.reshape((b_out, nc_out * t_out_chunk, h_out, w_out, c_out))
13021305

1303-
b_out, nc_out, t_out_chunk, h_out, w_out, c_out = encoded_combined.shape
1304-
encoded = encoded_combined.reshape((b_out, nc_out * t_out_chunk, h_out, w_out, c_out))
1306+
all_outs.append(encoded_rest)
1307+
1308+
encoded = jnp.concatenate(all_outs, axis=1)
13051309

13061310
enc, _ = self.quant_conv(encoded, cache_x=None)
1307-
# mu = enc[..., :self.z_dim]
1308-
# logvar = enc[..., self.z_dim:]
1309-
# h = jnp.concatenate([mu, logvar], axis=-1)
1310-
return enc # Return the direct output of quant_conv
1311+
return enc
13111312

13121313
# JIT compile the internal JAX-based function
13131314
_encode_compiled = nnx.jit(_encode_jit)
@@ -1316,48 +1317,59 @@ def encode(
13161317
self, x: jax.Array, return_dict: bool = True
13171318
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
13181319
"""Encodes the input array and returns custom distribution objects."""
1319-
if x.shape[-1] != 3:
1320-
# Transpose in the non-JIT part if needed, though _encode_jit handles it too
1321-
pass # Handled inside _encode_jit
1322-
1323-
# Call the JIT-compiled function to get the raw encoded array
13241320
h_params = self._encode_compiled(x)
1325-
1326-
# Create the custom Python objects from the JAX array results
13271321
posterior = FlaxDiagonalGaussianDistribution(h_params)
1328-
13291322
if not return_dict:
13301323
return (posterior,)
13311324
return FlaxAutoencoderKLOutput(latent_dist=posterior)
13321325

1333-
@nnx.jit
1334-
def decode(
1335-
self, z: jax.Array, return_dict: bool = True
1336-
) -> Union[FlaxDecoderOutput, jax.Array]:
1326+
def _decode_jit(self, z: jax.Array) -> jax.Array:
1327+
"""Core JAX decoding logic with scan and frame swapping."""
13371328
if z.shape[-1] != self.z_dim:
13381329
z = jnp.transpose(z, (0, 2, 3, 4, 1))
1339-
13401330
x, _ = self.post_quant_conv(z)
1341-
x_scan = jnp.swapaxes(x, 0, 1)
1342-
1331+
13431332
b, t, h, w, c = x.shape
13441333
init_cache = self.decoder.init_cache(b, h, w, x.dtype)
1334+
decoder_checkpointed = jax.checkpoint(self.decoder)
1335+
1336+
all_decoded = []
1337+
x_first = x[:, :1, ...]
1338+
out_first, state_carry = decoder_checkpointed(x_first, init_cache)
1339+
all_decoded.append(out_first)
1340+
if t > 1:
1341+
x_rest = x[:, 1:, ...]
1342+
x_scan = jnp.swapaxes(x_rest, 0, 1)
1343+
1344+
def scan_fn(carry, input_slice):
1345+
input_slice = jnp.expand_dims(input_slice, 1)
1346+
out_slice, new_carry = decoder_checkpointed(input_slice, carry)
1347+
out_swapped = out_slice[:, jnp.array([0, 2, 1, 3]), ...]
1348+
1349+
return new_carry, out_swapped
1350+
1351+
_, decoded_rest = jax.lax.scan(scan_fn, state_carry, x_scan)
1352+
1353+
decoded_rest = jnp.swapaxes(decoded_rest, 0, 1)
1354+
1355+
b_r, t_r, sub_t, h_r, w_r, c_r = decoded_rest.shape
1356+
decoded_rest = decoded_rest.reshape(b_r, t_r * sub_t, h_r, w_r, c_r)
1357+
1358+
all_decoded.append(decoded_rest)
1359+
1360+
out = jnp.concatenate(all_decoded, axis=1)
1361+
out = jnp.clip(out, min=-1.0, max=1.0)
1362+
1363+
return out
1364+
_decode_compiled = nnx.jit(_decode_jit)
13451365

1346-
def scan_fn(carry, input_slice):
1347-
input_slice = jnp.expand_dims(input_slice, 1)
1348-
out_slice, new_carry = self.decoder(input_slice, carry)
1349-
return new_carry, out_slice
1350-
1351-
# Need to provide a valid initial cache structure for the scan
1352-
final_cache, decoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
1353-
1354-
decoded = jnp.transpose(decoded_frames, (1, 0, 2, 3, 4, 5))
1355-
1356-
b, t_lat, t_sub, h, w, c = decoded.shape
1357-
decoded = decoded.reshape(b, t_lat * t_sub, h, w, c)
1358-
1359-
out = jnp.clip(decoded, min=-1.0, max=1.0)
1366+
def decode(
1367+
self, z: jax.Array, return_dict: bool = True
1368+
) -> Union[FlaxDecoderOutput, jax.Array]:
1369+
1370+
decoded = self._decode_compiled(z)
13601371

13611372
if not return_dict:
1362-
return (out,)
1363-
return FlaxDecoderOutput(sample=out)
1373+
return (decoded,)
1374+
1375+
return FlaxDecoderOutput(sample=decoded)

0 commit comments

Comments
 (0)