Skip to content

Commit 2ada7ec

Browse files
committed
fix
1 parent ce8ab0d commit 2ada7ec

1 file changed

Lines changed: 60 additions & 30 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 60 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1123,27 +1123,44 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11231123
x = jnp.transpose(x, (0, 2, 3, 4, 1))
11241124
assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
11251125

1126-
# Swap to (T, B, H, W, C) for scanning over time
1127-
x_scan = jnp.swapaxes(x, 0, 1)
1126+
t = x.shape[1]
11281127
enc_feat_map = feat_cache._enc_feat_map
11291128

1130-
def scan_fn(carry_cache, input_frame):
1131-
# Expand time dimension to 1 for the encoder
1132-
input_frame = jnp.expand_dims(input_frame, 1)
1133-
# idx is restarted at 0 for each chunk/frame conceptually
1134-
enc_conv_idx = 0
1135-
out_frame, new_cache, _ = self.encoder(input_frame, feat_cache=carry_cache, feat_idx=enc_conv_idx)
1136-
out_frame = jnp.squeeze(out_frame, 1)
1137-
return new_cache, out_frame
1138-
1139-
# Perform JAX scan
1140-
final_enc_feat_map, encoded_frames = jax.lax.scan(scan_fn, enc_feat_map, x_scan)
1141-
1142-
# Swap back to (B, T, ... )
1143-
out = jnp.swapaxes(encoded_frames, 0, 1)
1129+
# 1. Evaluate the first frame manually to establish the initial cache with JAX Arrays.
1130+
# This prevents jax.lax.scan from crashing on type mismatch between None and ShapedArray.
1131+
out_0, enc_feat_map, _ = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=0)
1132+
out = out_0
11441133

1145-
# Update back to the wrapper object if needed
1146-
feat_cache._enc_feat_map = final_enc_feat_map
1134+
# 2. Process remaining frames in chunks of 4 using jax.lax.scan (the reshape below assumes t - 1 is a multiple of 4)
1135+
if t > 1:
1136+
x_rest = x[:, 1:, :, :, :]
1137+
B, T_rest, H, W, C = x_rest.shape
1138+
num_chunks = T_rest // 4
1139+
1140+
# Reshape to (B, num_chunks, 4, H, W, C)
1141+
x_chunks = jnp.reshape(x_rest, (B, num_chunks, 4, H, W, C))
1142+
1143+
# Swap axes for scan traversal: (num_chunks, B, 4, H, W, C)
1144+
x_scan = jnp.swapaxes(x_chunks, 0, 1)
1145+
1146+
def scan_fn(carry_cache, input_chunk):
1147+
# input_chunk shape: (B, 4, H, W, C)
1148+
out_chunk, new_cache, _ = self.encoder(input_chunk, feat_cache=carry_cache, feat_idx=0)
1149+
# out_chunk shape: (B, 1, H', W', C')
1150+
return new_cache, out_chunk
1151+
1152+
enc_feat_map, scanned_out_chunks = jax.lax.scan(scan_fn, enc_feat_map, x_scan)
1153+
1154+
# scanned_out_chunks shape: (num_chunks, B, 1, H', W', C')
1155+
scanned_out_chunks = jnp.swapaxes(scanned_out_chunks, 0, 1)
1156+
1157+
B_out, _, _, H_out, W_out, C_out = scanned_out_chunks.shape
1158+
scanned_out_chunks = jnp.reshape(scanned_out_chunks, (B_out, num_chunks, H_out, W_out, C_out))
1159+
1160+
out = jnp.concatenate([out_0, scanned_out_chunks], axis=1)
1161+
1162+
# 3. Update back to the wrapper object if needed
1163+
feat_cache._enc_feat_map = enc_feat_map
11471164

11481165
enc = self.quant_conv(out)
11491166
mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
@@ -1170,22 +1187,35 @@ def _decode(
11701187

11711188
dec_feat_map = feat_cache._feat_map
11721189

1173-
# Evaluate the first frame manually to establish the initial cache.
1174-
# The decoder returns 1 frame on the first step, and 4 frames on subsequent steps due to temporal upsampling.
1190+
# 1. Evaluate the first frame manually (Cache: None -> RepSentinel/ShapedArray)
1191+
# The decoder returns 1 frame on the first step.
11751192
out_0, dec_feat_map, _ = self.decoder(x[:, 0:1, :, :, :], feat_cache=dec_feat_map, feat_idx=0)
11761193
out = out_0
11771194

1178-
# Process remaining frames using jax.lax.scan (requires homogenous output shapes)
1195+
# 2. Evaluate the second frame manually (Cache: RepSentinel -> ShapedArray)
1196+
# This ensures that ALL cache components are ShapedArrays before entering jax.lax.scan,
1197+
# preventing TraceContext errors due to type mismatches.
11791198
if iter_ > 1:
1180-
x_rest = x[:, 1:, :, :, :]
1181-
x_scan = jnp.swapaxes(x_rest, 0, 1) # (T-1, B, H, W, C)
1199+
out_1, dec_feat_map, _ = self.decoder(x[:, 1:2, :, :, :], feat_cache=dec_feat_map, feat_idx=0)
1200+
1201+
# Bypass an issue where frame[1] should be frame[2] and vice versa.
1202+
fm1 = out_1[:, 0:1, ...]
1203+
fm2 = out_1[:, 1:2, ...]
1204+
fm3 = out_1[:, 2:3, ...]
1205+
fm4 = out_1[:, 3:4, ...]
1206+
out_1_fixed = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
1207+
out = jnp.concatenate([out_0, out_1_fixed], axis=1)
1208+
1209+
# 3. Process remaining frames using jax.lax.scan (requires homogeneous output and carry shapes)
1210+
if iter_ > 2:
1211+
x_rest = x[:, 2:, :, :, :]
1212+
x_scan = jnp.swapaxes(x_rest, 0, 1) # (T-2, B, H, W, C)
11821213

11831214
def scan_fn(carry_cache, input_frame):
11841215
input_frame = jnp.expand_dims(input_frame, 1) # (B, 1, H, W, C)
11851216
out_frames, new_cache, _ = self.decoder(input_frame, feat_cache=carry_cache, feat_idx=0)
11861217

11871218
# Bypass an issue where frame[1] should be frame[2] and vice versa.
1188-
# Ensure dimensionality allows straightforward slicing:
11891219
fm1 = out_frames[:, 0:1, ...]
11901220
fm2 = out_frames[:, 1:2, ...]
11911221
fm3 = out_frames[:, 2:3, ...]
@@ -1196,17 +1226,17 @@ def scan_fn(carry_cache, input_frame):
11961226

11971227
dec_feat_map, scanned_out_frames = jax.lax.scan(scan_fn, dec_feat_map, x_scan)
11981228

1199-
# scanned_out_frames is (T-1, B, 4, H, W, C)
1229+
# scanned_out_frames is (T-2, B, 4, H, W, C)
12001230
B = scanned_out_frames.shape[1]
1201-
T_minus_1 = scanned_out_frames.shape[0]
1231+
T_minus_2 = scanned_out_frames.shape[0]
12021232
H, W, C = scanned_out_frames.shape[3], scanned_out_frames.shape[4], scanned_out_frames.shape[5]
12031233

1204-
# Swap back to (B, T-1, 4, H, W, C)
1234+
# Swap back to (B, T-2, 4, H, W, C)
12051235
scanned_out_frames = jnp.swapaxes(scanned_out_frames, 0, 1)
1206-
# Flatten the temporal axes to (B, (T-1)*4, H, W, C)
1207-
scanned_out_frames = jnp.reshape(scanned_out_frames, (B, T_minus_1 * 4, H, W, C))
1236+
# Flatten the temporal axes to (B, (T-2)*4, H, W, C)
1237+
scanned_out_frames = jnp.reshape(scanned_out_frames, (B, T_minus_2 * 4, H, W, C))
12081238

1209-
out = jnp.concatenate([out_0, scanned_out_frames], axis=1)
1239+
out = jnp.concatenate([out, scanned_out_frames], axis=1)
12101240

12111241
feat_cache._feat_map = dec_feat_map
12121242

0 commit comments

Comments
 (0)