Skip to content

Commit b7b6730

Browse files
committed
separate encode methods
1 parent 720031a commit b7b6730

1 file changed

Lines changed: 23 additions & 28 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 23 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1259,17 +1259,14 @@ def __init__(
12591259
precision=precision,
12601260
)
12611261

1262-
@nnx.jit # JIT the whole encode method
1263-
def encode(
1264-
self, x: jax.Array, return_dict: bool = True
1265-
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
1262+
def _encode_jit(self, x: jax.Array) -> jax.Array:
1263+
"""Contains the core JAX computations for encoding, suitable for JIT."""
12661264
if x.shape[-1] != 3:
12671265
x = jnp.transpose(x, (0, 2, 3, 4, 1))
1268-
assert x.shape[-1] == 3, "Input channels must be 3"
1266+
# assert x.shape[-1] == 3, "Input channels must be 3" # Assertions might not be ideal in JIT
12691267

12701268
b, t, h, w, c = x.shape
12711269
chunk_size = 4 # Process in chunks of 4 frames
1272-
# Assuming the encoder downsamples time by a factor of 4 overall
12731270

12741271
num_chunks = math.ceil(t / chunk_size)
12751272
padded_t = num_chunks * chunk_size
@@ -1289,47 +1286,45 @@ def encode(
12891286
# Swap axes for scan: (Num_Chunks, B, Chunk_T, H, W, C)
12901287
x_scannable = jnp.swapaxes(x_reshaped, 0, 1)
12911288

1292-
# Wrap the encoder's call method with jax.checkpoint
1293-
# nnx.Module instances are callable, so this works.
12941289
encoder_checkpointed = jax.checkpoint(self.encoder)
12951290

12961291
def scan_fn(dummy_carry, x_chunk):
12971292
# x_chunk shape: (B, chunk_size, H, W, C)
12981293
b_c, _, h_c, w_c, _ = x_chunk.shape
1299-
1300-
# Reset cache for each chunk to ensure independence as per original logic
13011294
init_cache = self.encoder.init_cache(b_c, h_c, w_c, x_chunk.dtype)
1302-
1303-
# Call the checkpointed encoder
13041295
out_chunk, _ = encoder_checkpointed(x_chunk, init_cache)
1305-
# Expected out_chunk shape: (B, 1, H', W', Z*2), assuming 4x temporal downsampling per chunk
1306-
13071296
return dummy_carry, out_chunk
13081297

1309-
# The initial carry structure for scan needs to match the output carry structure of scan_fn.
1310-
# Since we don't propagate the cache *between* chunks, dummy_carry can be simple.
13111298
initial_scan_carry = {}
1312-
1313-
# Run the scan over the chunks
13141299
_, encoded_chunks = jax.lax.scan(scan_fn, initial_scan_carry, x_scannable)
1315-
# encoded_chunks shape: (num_chunks, B, 1, H', W', Z*2)
13161300

1317-
# Concatenate the results from each chunk
1318-
# Transpose back to (B, num_chunks, 1, H', W', Z*2)
13191301
encoded_combined = jnp.swapaxes(encoded_chunks, 0, 1)
13201302

1321-
# Reshape to (B, num_chunks * 1, H', W', Z*2) -> (B, num_chunks, H', W', Z*2)
13221303
b_out, nc_out, t_out_chunk, h_out, w_out, c_out = encoded_combined.shape
13231304
encoded = encoded_combined.reshape((b_out, nc_out * t_out_chunk, h_out, w_out, c_out))
1324-
# Final 'encoded' shape: (B, 3, H', W', Z*2) for T=9 input
13251305

1326-
# Post-processing to get distribution parameters
13271306
enc, _ = self.quant_conv(encoded, cache_x=None)
1328-
mu = enc[..., :self.z_dim]
1329-
logvar = enc[..., self.z_dim:]
1330-
h = jnp.concatenate([mu, logvar], axis=-1)
1307+
# mu = enc[..., :self.z_dim]
1308+
# logvar = enc[..., self.z_dim:]
1309+
# h = jnp.concatenate([mu, logvar], axis=-1)
1310+
return enc # Return the direct output of quant_conv
1311+
1312+
# JIT compile the internal JAX-based function
1313+
_encode_compiled = nnx.jit(_encode_jit)
1314+
1315+
def encode(
1316+
self, x: jax.Array, return_dict: bool = True
1317+
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
1318+
"""Encodes the input array and returns custom distribution objects."""
1319+
if x.shape[-1] != 3:
1320+
# Transpose in the non-JIT part if needed, though _encode_jit handles it too
1321+
pass # Handled inside _encode_jit
1322+
1323+
# Call the JIT-compiled function to get the raw encoded array
1324+
h_params = self._encode_compiled(x)
13311325

1332-
posterior = FlaxDiagonalGaussianDistribution(h)
1326+
# Create the custom Python objects from the JAX array results
1327+
posterior = FlaxDiagonalGaussianDistribution(h_params)
13331328

13341329
if not return_dict:
13351330
return (posterior,)

0 commit comments

Comments
 (0)