Skip to content

Commit c1af338

Browse files
committed
trying full jit compile with spatial sharding
1 parent c5fb919 commit c1af338

1 file changed

Lines changed: 55 additions & 38 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 55 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) ->
156156
shard_width_axis = "context"
157157

158158
x_padded = jax.lax.with_sharding_constraint(
159-
x_padded, jax.sharding.PartitionSpec(None, None, shard_axis, shard_width_axis, None)
159+
x_padded, jax.sharding.PartitionSpec("data", None, shard_axis, shard_width_axis, None)
160160
)
161161

162162
out = self.conv(x_padded)
@@ -1125,24 +1125,27 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11251125
x = jnp.transpose(x, (0, 2, 3, 4, 1))
11261126
assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
11271127

1128-
t = x.shape[1]
1129-
iter_ = 1 + (t - 1) // 4
1128+
# Swap to (T, B, H, W, C) for scanning over time
1129+
x_scan = jnp.swapaxes(x, 0, 1)
11301130
enc_feat_map = feat_cache._enc_feat_map
11311131

1132-
for i in range(iter_):
1132+
def scan_fn(carry_cache, input_frame):
1133+
# Expand time dimension to 1 for the encoder
1134+
input_frame = jnp.expand_dims(input_frame, 1)
1135+
# feat_idx restarts at 0 for each chunk/frame
11331136
enc_conv_idx = 0
1134-
if i == 0:
1135-
out, enc_feat_map, enc_conv_idx = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx)
1136-
else:
1137-
out_, enc_feat_map, enc_conv_idx = self.encoder(
1138-
x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :],
1139-
feat_cache=enc_feat_map,
1140-
feat_idx=enc_conv_idx,
1141-
)
1142-
out = jnp.concatenate([out, out_], axis=1)
1137+
out_frame, new_cache, _ = self.encoder(input_frame, feat_cache=carry_cache, feat_idx=enc_conv_idx)
1138+
out_frame = jnp.squeeze(out_frame, 1)
1139+
return new_cache, out_frame
11431140

1144-
# Update back to the wrapper object if needed, but for result we use local vars
1145-
feat_cache._enc_feat_map = enc_feat_map
1141+
# Perform JAX scan
1142+
final_enc_feat_map, encoded_frames = jax.lax.scan(scan_fn, enc_feat_map, x_scan)
1143+
1144+
# Swap back to (B, T, ... )
1145+
out = jnp.swapaxes(encoded_frames, 0, 1)
1146+
1147+
# Update back to the wrapper object if needed
1148+
feat_cache._enc_feat_map = final_enc_feat_map
11461149

11471150
enc = self.quant_conv(out)
11481151
mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
@@ -1169,29 +1172,43 @@ def _decode(
11691172

11701173
dec_feat_map = feat_cache._feat_map
11711174

1172-
for i in range(iter_):
1173-
conv_idx = 0
1174-
if i == 0:
1175-
out, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
1176-
else:
1177-
out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
1178-
1179-
# This is to bypass an issue where frame[1] should be frame[2] and vice versa.
1180-
# Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
1181-
# Most likely due to an incorrect reshaping in the decoder.
1182-
fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
1183-
# When batch_size is 1, expand batch dim for concatenation
1184-
# else, expand frame dim for concatenation so that batch dim stays intact.
1185-
axis = 0
1186-
if fm1.shape[0] > 1:
1187-
axis = 1
1188-
1189-
if len(fm1.shape) == 4:
1190-
fm1 = jnp.expand_dims(fm1, axis=axis)
1191-
fm2 = jnp.expand_dims(fm2, axis=axis)
1192-
fm3 = jnp.expand_dims(fm3, axis=axis)
1193-
fm4 = jnp.expand_dims(fm4, axis=axis)
1194-
out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
1175+
# Evaluate the first frame manually to establish the initial cache.
1176+
# The decoder returns 1 frame on the first step, and 4 frames on subsequent steps due to temporal upsampling.
1177+
out_0, dec_feat_map, _ = self.decoder(x[:, 0:1, :, :, :], feat_cache=dec_feat_map, feat_idx=0)
1178+
out = out_0
1179+
1180+
# Process remaining frames using jax.lax.scan (requires homogeneous output shapes)
1181+
if iter_ > 1:
1182+
x_rest = x[:, 1:, :, :, :]
1183+
x_scan = jnp.swapaxes(x_rest, 0, 1) # (T-1, B, H, W, C)
1184+
1185+
def scan_fn(carry_cache, input_frame):
1186+
input_frame = jnp.expand_dims(input_frame, 1) # (B, 1, H, W, C)
1187+
out_frames, new_cache, _ = self.decoder(input_frame, feat_cache=carry_cache, feat_idx=0)
1188+
1189+
# Bypass an issue where frame[1] should be frame[2] and vice versa.
1190+
# Ensure dimensionality allows straightforward slicing:
1191+
fm1 = out_frames[:, 0:1, ...]
1192+
fm2 = out_frames[:, 1:2, ...]
1193+
fm3 = out_frames[:, 2:3, ...]
1194+
fm4 = out_frames[:, 3:4, ...]
1195+
1196+
fixed_out_frames = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
1197+
return new_cache, fixed_out_frames
1198+
1199+
dec_feat_map, scanned_out_frames = jax.lax.scan(scan_fn, dec_feat_map, x_scan)
1200+
1201+
# scanned_out_frames is (T-1, B, 4, H, W, C)
1202+
B = scanned_out_frames.shape[1]
1203+
T_minus_1 = scanned_out_frames.shape[0]
1204+
H, W, C = scanned_out_frames.shape[3], scanned_out_frames.shape[4], scanned_out_frames.shape[5]
1205+
1206+
# Swap back to (B, T-1, 4, H, W, C)
1207+
scanned_out_frames = jnp.swapaxes(scanned_out_frames, 0, 1)
1208+
# Flatten the temporal axes to (B, (T-1)*4, H, W, C)
1209+
scanned_out_frames = jnp.reshape(scanned_out_frames, (B, T_minus_1 * 4, H, W, C))
1210+
1211+
out = jnp.concatenate([out_0, scanned_out_frames], axis=1)
11951212

11961213
feat_cache._feat_map = dec_feat_map
11971214

0 commit comments

Comments
 (0)