
Commit 1f712f8

committed
NoneType error fixed
1 parent 4fdabf7 commit 1f712f8

1 file changed

Lines changed: 28 additions & 78 deletions


src/maxdiffusion/models/wan/autoencoder_kl_wan.py

@@ -1197,84 +1197,34 @@ def __init__(
   def encode(
       self, x: jax.Array, return_dict: bool = True
   ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
-    if x.shape[-1] != 3:
-      x = jnp.transpose(x, (0, 2, 3, 4, 1))
-
-    b, t, h, w, c = x.shape
-
-    # --- STEP 1: Process First Frame (With Padding Hack) ---
-    x_first = x[:, :1, ...]  # (B, 1, ...)
-
-    # We PAD this single frame to T=4 so it survives the strides.
-    # We will take only the first result.
-    x_first_padded = jnp.concatenate([x_first] * 4, axis=1)  # (B, 4, ...)
-
-    # Initialize Cache
-    init_cache = self.encoder.init_cache(b, h, w, x.dtype)
-
-    # Run Encoder on padded first frame
-    # We discard the cache update here because this is a "Fake" run to get the latent
-    # BUT wait, we need the cache state for the next frames.
-    # This is tricky. If we pad [0, 0, 0, 0], the cache will be filled with Frame 0's history.
-    # This is actually correct for a static image or start of video.
-
-    enc_first_padded, cache_after_first = self.encoder(x_first_padded, init_cache)
-
-    # Take only the first frame of the output
-    enc_first = enc_first_padded[:, :1, ...]
-
-    # --- STEP 2: Process Rest of Frames (Chunks of 4) ---
-    x_rest = x[:, 1:, ...]
-    t_rest = t - 1
-
-    # Pad remainder to be divisible by 4
-    pad_len = (4 - (t_rest % 4)) % 4
-    if pad_len > 0:
-      last = x_rest[:, -1:, ...]
-      padding = jnp.repeat(last, pad_len, axis=1)
-      x_rest_padded = jnp.concatenate([x_rest, padding], axis=1)
-    else:
-      x_rest_padded = x_rest
-
-    num_chunks = x_rest_padded.shape[1] // 4
-    x_chunks = x_rest_padded.reshape(b, num_chunks, 4, h, w, c)
-    x_chunks = jnp.transpose(x_chunks, (1, 0, 2, 3, 4, 5))
-
-    # Scan Function
-    def scan_fn(carry, input_chunk):
-      out_chunk, new_carry = self.encoder(input_chunk, carry)
-      return new_carry, out_chunk
-
-    final_cache, enc_rest_chunks = jax.lax.scan(scan_fn, cache_after_first, x_chunks)
-
-    # Flatten Rest
-    enc_rest_chunks = jnp.swapaxes(enc_rest_chunks, 0, 1)
-    b_out, n_chunks, t_chunk, h_out, w_out, c_out = enc_rest_chunks.shape
-    enc_rest = enc_rest_chunks.reshape(b_out, n_chunks * t_chunk, h_out, w_out, c_out)
-
-    # Slice off padding from result if needed
-    # We padded input by 'pad_len'. Output is downsampled by 4 (likely).
-    # Actually, since we chunked by 4 and got 1 output, the mapping is 1-to-1 chunk-to-latent.
-    # If we added 1 chunk of padding, we remove 1 frame of output.
-    if pad_len > 0:
-      # We padded inputs. Does that mean we generated extra latents?
-      # If t_rest=5. Pad to 8. Chunks=2. Output=2 latents.
-      # Real latents needed: ceil(5/4) = 2.
-      # So actually, we don't need to slice! The ceiling behavior is what we want.
-      pass
-
-    # Concatenate
-    encoded = jnp.concatenate([enc_first, enc_rest], axis=1)
-
-    # Quantize
-    enc, _ = self.quant_conv(encoded)
-    mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
-    h_latents = jnp.concatenate([mu, logvar], axis=-1)
-
-    posterior = FlaxDiagonalGaussianDistribution(h_latents)
-    if not return_dict:
-      return (posterior,)
-    return FlaxAutoencoderKLOutput(latent_dist=posterior)
+    if x.shape[-1] != 3:
+      # reshape channel last for JAX
+      x = jnp.transpose(x, (0, 2, 3, 4, 1))
+    assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
+
+    x_scan = jnp.swapaxes(x, 0, 1)
+    b, t, h, w, c = x.shape
+    init_cache = self.encoder.init_cache(b, h, w, x.dtype)
+
+    def scan_fn(carry, input_slice):
+      # Expand Time dimension for Conv3d
+      input_slice = jnp.expand_dims(input_slice, 1)
+      out_slice, new_carry = self.encoder(input_slice, carry)
+      # Squeeze Time dimension for scan stacking
+      out_slice = jnp.squeeze(out_slice, 1)
+      return new_carry, out_slice
+
+    final_cache, encoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
+    encoded = jnp.swapaxes(encoded_frames, 0, 1)
+    enc, _ = self.quant_conv(encoded)
+
+    mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
+    h = jnp.concatenate([mu, logvar], axis=-1)
+
+    posterior = FlaxDiagonalGaussianDistribution(h)
+    if not return_dict:
+      return (posterior,)
+    return FlaxAutoencoderKLOutput(latent_dist=posterior)
 
   @nnx.jit
   def decode(
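
Note on the new encode path: it is the standard jax.lax.scan carry pattern. Time becomes the scan axis, the encoder cache is the carry, and each step re-expands its frame to a length-1 time dimension so the 3D convolutions see a (B, 1, H, W, C) slice, then squeezes it again so scan can stack the per-frame outputs. A minimal self-contained sketch of that pattern follows; toy_encoder and the zero-initialized cache are hypothetical stand-ins for the real self.encoder and encoder.init_cache, not the actual maxdiffusion code.

import jax
import jax.numpy as jnp

# Hypothetical stand-in for self.encoder: takes a (B, 1, H, W, C) slice
# plus a cache, returns an output slice of the same shape and the
# updated cache.
def toy_encoder(x_slice, cache):
  out = x_slice + cache[:, None, ...]          # use the carried state
  new_cache = cache + jnp.squeeze(x_slice, 1)  # accumulate frame history
  return out, new_cache

def encode_frames(x):
  # x: (B, T, H, W, C), channel-last as the commit's assert expects.
  b, t, h, w, c = x.shape
  x_scan = jnp.swapaxes(x, 0, 1)                 # (T, B, H, W, C): scan over time
  init_cache = jnp.zeros((b, h, w, c), x.dtype)  # stand-in for init_cache

  def scan_fn(carry, frame):
    frame = jnp.expand_dims(frame, 1)        # re-add the time dim for the 3D conv
    out, new_carry = toy_encoder(frame, carry)
    return new_carry, jnp.squeeze(out, 1)    # drop it so scan can stack outputs

  final_cache, encoded = jax.lax.scan(scan_fn, init_cache, x_scan)
  return jnp.swapaxes(encoded, 0, 1)         # back to (B, T, H, W, C)

print(encode_frames(jnp.ones((1, 5, 8, 8, 3))).shape)  # (1, 5, 8, 8, 3)

Because the cache carries all temporal history between steps, scanning one frame at a time also eliminates the old version's first-frame padding hack and chunk-of-4 padding bookkeeping.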
