Skip to content

Commit 95b0141

Browse files
committed
jax.lax.scan
1 parent 18d167c commit 95b0141

1 file changed

Lines changed: 69 additions & 34 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,17 +1128,36 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11281128
iter_ = 1 + (t - 1) // 4
11291129
enc_feat_map = feat_cache._enc_feat_map
11301130

1131-
for i in range(iter_):
1132-
enc_conv_idx = 0
1133-
if i == 0:
1134-
out, enc_feat_map, enc_conv_idx = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx)
1135-
else:
1136-
out_, enc_feat_map, enc_conv_idx = self.encoder(
1137-
x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :],
1138-
feat_cache=enc_feat_map,
1139-
feat_idx=enc_conv_idx,
1140-
)
1141-
out = jnp.concatenate([out, out_], axis=1)
1131+
# Process first chunk explicitly
1132+
out_first, enc_feat_map, _ = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=0)
1133+
1134+
# Prepare remaining chunks for scan
1135+
def scan_body_encode(carry, x_chunk):
1136+
feat_map = carry
1137+
out_chunk, updated_feat_map, _ = self.encoder(x_chunk, feat_cache=feat_map, feat_idx=0)
1138+
return updated_feat_map, out_chunk
1139+
1140+
if iter_ > 1:
1141+
# We have remaining chunks to process. Let's reshape/stack them.
1142+
# x is (B, T, H, W, C) where T = 1 + 4 * (iter_ - 1)
1143+
# We want to scan over the iter_-1 blocks of size 4.
1144+
x_rest = x[:, 1:, :, :, :]
1145+
b, t_rest, h, w, c = x_rest.shape
1146+
x_rest_blocks = x_rest.reshape(b, iter_ - 1, 4, h, w, c)
1147+
# scan over the blocks dimension (axis=1) -> swap axis 0 and 1
1148+
x_scan_input = jnp.swapaxes(x_rest_blocks, 0, 1) # shape: (iter_ - 1, B, 4, H, W, C)
1149+
1150+
enc_feat_map, out_rest_stacked = jax.lax.scan(scan_body_encode, enc_feat_map, x_scan_input)
1151+
# out_rest_stacked shape: (iter_ - 1, B, T_out_chunk, H_out, W_out, C_out)
1152+
1153+
# Transpose back and flatten the iteration and time dimensions
1154+
out_rest_stacked = jnp.swapaxes(out_rest_stacked, 0, 1)
1155+
b_out, iters_out, t_out_chunk, h_out, w_out, c_out = out_rest_stacked.shape
1156+
out_rest = out_rest_stacked.reshape(b_out, iters_out * t_out_chunk, h_out, w_out, c_out)
1157+
1158+
out = jnp.concatenate([out_first, out_rest], axis=1)
1159+
else:
1160+
out = out_first
11421161

11431162
enc = self.quant_conv(out)
11441163
mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
@@ -1166,29 +1185,45 @@ def _decode(
11661185

11671186
dec_feat_map = feat_cache._feat_map
11681187

1169-
for i in range(iter_):
1170-
conv_idx = 0
1171-
if i == 0:
1172-
out, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
1173-
else:
1174-
out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
1175-
1176-
# This is to bypass an issue where frame[1] should be frame[2] and vice versa.
1177-
# Ideally we shouldn't need to do this; however, we can't find where the frames go out of sync.
1178-
# Most likely due to an incorrect reshaping in the decoder.
1179-
fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
1180-
# When batch_size is 1, expand batch dim for concatenation
1181-
# else, expand frame dim for concatenation so that batch dim stays intact.
1182-
axis = 0
1183-
if fm1.shape[0] > 1:
1184-
axis = 1
1185-
1186-
if len(fm1.shape) == 4:
1187-
fm1 = jnp.expand_dims(fm1, axis=axis)
1188-
fm2 = jnp.expand_dims(fm2, axis=axis)
1189-
fm3 = jnp.expand_dims(fm3, axis=axis)
1190-
fm4 = jnp.expand_dims(fm4, axis=axis)
1191-
out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
1188+
def process_out_frame(out_):
1189+
# This is to bypass an issue where frame[1] should be frame[2] and vice versa.
1190+
# Ideally we shouldn't need to do this; however, we can't find where the frames go out of sync.
1191+
# Most likely due to an incorrect reshaping in the decoder.
1192+
fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
1193+
axis = 0
1194+
if fm1.shape[0] > 1:
1195+
axis = 1
1196+
1197+
if len(fm1.shape) == 4:
1198+
fm1 = jnp.expand_dims(fm1, axis=axis)
1199+
fm2 = jnp.expand_dims(fm2, axis=axis)
1200+
fm3 = jnp.expand_dims(fm3, axis=axis)
1201+
fm4 = jnp.expand_dims(fm4, axis=axis)
1202+
return jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
1203+
1204+
# Process first chunk explicitly
1205+
out_first, dec_feat_map, _ = self.decoder(x[:, :1, :, :, :], feat_cache=dec_feat_map, feat_idx=0)
1206+
1207+
def scan_body_decode(carry, x_chunk):
1208+
feat_map = carry
1209+
out_chunk, updated_feat_map, _ = self.decoder(x_chunk, feat_cache=feat_map, feat_idx=0)
1210+
out_processed = process_out_frame(out_chunk)
1211+
return updated_feat_map, out_processed
1212+
1213+
if iter_ > 1:
1214+
x_rest = x[:, 1:, :, :, :]
1215+
# Scan over the time dimension directly
1216+
x_scan_input = jnp.swapaxes(jnp.expand_dims(x_rest, axis=2), 0, 1) # shape: (iter_ - 1, B, 1, H, W, C)
1217+
1218+
dec_feat_map, out_rest_stacked = jax.lax.scan(scan_body_decode, dec_feat_map, x_scan_input)
1219+
1220+
out_rest_stacked = jnp.swapaxes(out_rest_stacked, 0, 1)
1221+
b_out, iters_out, t_out_frames, h_out, w_out, c_out = out_rest_stacked.shape
1222+
out_rest = out_rest_stacked.reshape(b_out, iters_out * t_out_frames, h_out, w_out, c_out)
1223+
1224+
out = jnp.concatenate([out_first, out_rest], axis=1)
1225+
else:
1226+
out = out_first
11921227

11931228
out = jnp.clip(out, min=-1.0, max=1.0)
11941229
return out

0 commit comments

Comments
 (0)