Skip to content

Commit 720031a

Browse files
committed
encode decode modified
1 parent 283c52d commit 720031a

1 file changed

Lines changed: 18 additions & 14 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,7 +1259,7 @@ def __init__(
12591259
precision=precision,
12601260
)
12611261

1262-
@nnx.jit
1262+
@nnx.jit # JIT the whole encode method
12631263
def encode(
12641264
self, x: jax.Array, return_dict: bool = True
12651265
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
@@ -1269,13 +1269,14 @@ def encode(
12691269

12701270
b, t, h, w, c = x.shape
12711271
chunk_size = 4 # Process in chunks of 4 frames
1272+
# Assuming the encoder downsamples time by a factor of 4 overall
12721273

1273-
# Calculate padding needed to make the time dimension a multiple of chunk_size
12741274
num_chunks = math.ceil(t / chunk_size)
12751275
padded_t = num_chunks * chunk_size
12761276
padding_t = padded_t - t
12771277

12781278
if padding_t > 0:
1279+
# Pad the time dimension to be a multiple of chunk_size
12791280
paddings = [(0, 0)] * x.ndim
12801281
paddings[1] = (0, padding_t) # Pad at the end of the time dimension
12811282
x_padded = jnp.pad(x, paddings, mode='constant', constant_values=0.0)
@@ -1288,23 +1289,26 @@ def encode(
12881289
# Swap axes for scan: (Num_Chunks, B, Chunk_T, H, W, C)
12891290
x_scannable = jnp.swapaxes(x_reshaped, 0, 1)
12901291

1291-
# Define the function to be executed in each step of the scan
1292+
# Wrap the encoder's call method with jax.checkpoint
1293+
# nnx.Module instances are callable, so this works.
1294+
encoder_checkpointed = jax.checkpoint(self.encoder)
1295+
12921296
def scan_fn(dummy_carry, x_chunk):
12931297
# x_chunk shape: (B, chunk_size, H, W, C)
12941298
b_c, _, h_c, w_c, _ = x_chunk.shape
12951299

1296-
# Reset cache for each chunk to ensure independence
1300+
# Reset cache for each chunk to ensure independence as per original logic
12971301
init_cache = self.encoder.init_cache(b_c, h_c, w_c, x_chunk.dtype)
12981302

1299-
# Use gradient checkpointing to save memory
1300-
out_chunk, _ = nnx.checkpoint(self.encoder)(x_chunk, init_cache)
1301-
# Expected out_chunk shape: (B, 1, H', W', Z*2)
1302-
# as each 4-frame chunk is downsampled temporally by 4x.
1303+
# Call the checkpointed encoder
1304+
out_chunk, _ = encoder_checkpointed(x_chunk, init_cache)
1305+
# Expected out_chunk shape: (B, 1, H', W', Z*2), assuming 4x temporal downsampling per chunk
13031306

13041307
return dummy_carry, out_chunk
13051308

1306-
# Initial carry for scan - not used for state propagation between chunks
1307-
initial_scan_carry = self.encoder.init_cache(b, h, w, x.dtype)
1309+
# The initial carry structure for scan needs to match the output carry structure of scan_fn.
1310+
# Since we don't propagate the cache *between* chunks, dummy_carry can be simple.
1311+
initial_scan_carry = {}
13081312

13091313
# Run the scan over the chunks
13101314
_, encoded_chunks = jax.lax.scan(scan_fn, initial_scan_carry, x_scannable)
@@ -1314,11 +1318,10 @@ def scan_fn(dummy_carry, x_chunk):
13141318
# Transpose back to (B, num_chunks, 1, H', W', Z*2)
13151319
encoded_combined = jnp.swapaxes(encoded_chunks, 0, 1)
13161320

1317-
# Reshape to (B, num_chunks, H', W', Z*2)
1321+
# Reshape to (B, num_chunks * 1, H', W', Z*2) -> (B, num_chunks, H', W', Z*2)
13181322
b_out, nc_out, t_out_chunk, h_out, w_out, c_out = encoded_combined.shape
13191323
encoded = encoded_combined.reshape((b_out, nc_out * t_out_chunk, h_out, w_out, c_out))
1320-
# Final 'encoded' shape: (B, num_chunks, H', W', Z*2)
1321-
# For T=9, num_chunks=3. This matches the expected (B, 3, H', W', Z*2)
1324+
# Final 'encoded' shape: (B, 3, H', W', Z*2) for T=9 input
13221325

13231326
# Post-processing to get distribution parameters
13241327
enc, _ = self.quant_conv(encoded, cache_x=None)
@@ -1350,7 +1353,8 @@ def scan_fn(carry, input_slice):
13501353
out_slice, new_carry = self.decoder(input_slice, carry)
13511354
return new_carry, out_slice
13521355

1353-
final_cache, decoded_frames = jax.lax.scan(scan_fn, initial_scan_carry, x_scan)
1356+
# Need to provide a valid initial cache structure for the scan
1357+
final_cache, decoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
13541358

13551359
decoded = jnp.transpose(decoded_frames, (1, 0, 2, 3, 4, 5))
13561360

0 commit comments

Comments
 (0)