Skip to content

Commit 94cc109

Browse files
committed
encode restored
1 parent ec668df commit 94cc109

1 file changed

Lines changed: 14 additions & 17 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 14 additions & 17 deletions
Diff (reconstructed from the garbled two-column render; original/new line numbers omitted):

```diff
@@ -1258,34 +1258,31 @@ def __init__(
         precision=precision,
     )

-  # REMOVE @nnx.jit for now to ensure this logic runs
+  @nnx.jit
   def encode(
       self, x: jax.Array, return_dict: bool = True
   ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
     if x.shape[-1] != 3:
+      # reshape channel last for JAX
       x = jnp.transpose(x, (0, 2, 3, 4, 1))
     assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"

+    x_scan = jnp.swapaxes(x, 0, 1)
     b, t, h, w, c = x.shape
     init_cache = self.encoder.init_cache(b, h, w, x.dtype)

-    # Process first frame
-    out1, _ = self.encoder(x[:, :1, ...], init_cache)
-
-    if t > 1:
-      # Process remaining frames in one chunk
-      # We need to manage cache updates manually if not using scan
-      # This part is tricky because the new encoder returns cache,
-      # but the old logic didn't seem to carry cache between chunks.
+    def scan_fn(carry, input_slice):
+      # Expand Time dimension for Conv3d
+      input_slice = jnp.expand_dims(input_slice, 1)
+      out_slice, new_carry = self.encoder(input_slice, carry)
+      # Squeeze Time dimension for scan stacking
+      out_slice = jnp.squeeze(out_slice, 1)
+      return new_carry, out_slice

-      # Let's SIMPLIFY to match the OLD logic's spirit: Reset cache for the chunk
-      init_cache_rest = self.encoder.init_cache(b, h, w, x.dtype)
-      out_rest, _ = self.encoder(x[:, 1:, ...], init_cache_rest)
-      encoded = jnp.concatenate([out1, out_rest], axis=1)
-    else:
-      encoded = out1
+    final_cache, encoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
+    encoded = jnp.swapaxes(encoded_frames, 0, 1)
+    enc, _ = self.quant_conv(encoded)

-    enc, _ = self.quant_conv(encoded, cache_x=None)
     mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
     h = jnp.concatenate([mu, logvar], axis=-1)
```

```diff
@@ -1294,7 +1291,7 @@ def encode(
       return (posterior,)
     return FlaxAutoencoderKLOutput(latent_dist=posterior)

-  # @nnx.jit
+  @nnx.jit
   def decode(
       self, z: jax.Array, return_dict: bool = True
   ) -> Union[FlaxDecoderOutput, jax.Array]:
```

0 commit comments

Comments
 (0)