Skip to content

Commit 24b4e57

Browse files
committed
nnx.jit for encode decode
1 parent 15f9c27 commit 24b4e57

1 file changed

Lines changed: 99 additions & 85 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 99 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,44 +1126,51 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11261126
t = x.shape[1]
11271127
enc_feat_map = feat_cache._enc_feat_map
11281128

1129-
# 1. Evaluate the first frame manually to establish the initial cache with JAX Arrays.
1130-
# This prevents jax.lax.scan from crashing on type mismatch between None and ShapedArray.
1131-
out_0, enc_feat_map, _ = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=0)
1132-
out = out_0
1133-
1134-
# 2. Evaluate the second chunk (4 frames) manually to stabilize WanCausalConv3d caches to T=2.
1135-
# WanCausalConv3d uses cache_x = x[:, -2:]. After 1 frame, cache is T=1. After 4 frames, it stabilizes to T=2.
1136-
if t > 1:
1137-
out_1, enc_feat_map, _ = self.encoder(x[:, 1:5, :, :, :], feat_cache=enc_feat_map, feat_idx=0)
1138-
out = jnp.concatenate([out_0, out_1], axis=1)
1139-
1140-
# 3. Process remaining frames in chunks of 4 using jax.lax.scan
1141-
if t > 5:
1142-
x_rest = x[:, 5:, :, :, :]
1143-
B, T_rest, H, W, C = x_rest.shape
1144-
num_chunks = T_rest // 4
1145-
1146-
# Reshape to (B, num_chunks, 4, H, W, C)
1147-
x_chunks = jnp.reshape(x_rest, (B, num_chunks, 4, H, W, C))
1148-
1149-
# Swap axes for scan traversal: (num_chunks, B, 4, H, W, C)
1150-
x_scan = jnp.swapaxes(x_chunks, 0, 1)
1151-
1152-
def scan_fn(carry_cache, input_chunk):
1153-
# input_chunk shape: (B, 4, H, W, C)
1154-
out_chunk, new_cache, _ = self.encoder(input_chunk, feat_cache=carry_cache, feat_idx=0)
1155-
# out_chunk shape: (B, 1, H', W', C')
1156-
return new_cache, out_chunk
1129+
@nnx.jit
1130+
def encode_sequence(encoder, x_seq, current_enc_feat_map):
1131+
t_seq = x_seq.shape[1]
1132+
# 1. Evaluate the first frame manually to establish the initial cache with JAX Arrays.
1133+
# This prevents jax.lax.scan from crashing on type mismatch between None and ShapedArray.
1134+
out_0, current_enc_feat_map, _ = encoder(x_seq[:, :1, :, :, :], feat_cache=current_enc_feat_map, feat_idx=0)
1135+
out_seq = out_0
1136+
1137+
# 2. Evaluate the second chunk (4 frames) manually to stabilize WanCausalConv3d caches to T=2.
1138+
# WanCausalConv3d uses cache_x = x[:, -2:]. After 1 frame, cache is T=1. After 4 frames, it stabilizes to T=2.
1139+
if t_seq > 1:
1140+
out_1, current_enc_feat_map, _ = encoder(x_seq[:, 1:5, :, :, :], feat_cache=current_enc_feat_map, feat_idx=0)
1141+
out_seq = jnp.concatenate([out_0, out_1], axis=1)
1142+
1143+
# 3. Process remaining frames in chunks of 4 using jax.lax.scan
1144+
if t_seq > 5:
1145+
x_rest = x_seq[:, 5:, :, :, :]
1146+
B, T_rest, H, W, C = x_rest.shape
1147+
num_chunks = T_rest // 4
1148+
1149+
# Reshape to (B, num_chunks, 4, H, W, C)
1150+
x_chunks = jnp.reshape(x_rest, (B, num_chunks, 4, H, W, C))
11571151

1158-
enc_feat_map, scanned_out_chunks = jax.lax.scan(scan_fn, enc_feat_map, x_scan)
1159-
1160-
# scanned_out_chunks shape: (num_chunks, B, 1, H', W', C')
1161-
scanned_out_chunks = jnp.swapaxes(scanned_out_chunks, 0, 1)
1162-
1163-
B_out, _, _, H_out, W_out, C_out = scanned_out_chunks.shape
1164-
scanned_out_chunks = jnp.reshape(scanned_out_chunks, (B_out, num_chunks, H_out, W_out, C_out))
1152+
# Swap axes for scan traversal: (num_chunks, B, 4, H, W, C)
1153+
x_scan = jnp.swapaxes(x_chunks, 0, 1)
1154+
1155+
def scan_fn(carry_cache, input_chunk):
1156+
# input_chunk shape: (B, 4, H, W, C)
1157+
out_chunk, new_cache, _ = encoder(input_chunk, feat_cache=carry_cache, feat_idx=0)
1158+
# out_chunk shape: (B, 1, H', W', C')
1159+
return new_cache, out_chunk
1160+
1161+
current_enc_feat_map, scanned_out_chunks = jax.lax.scan(scan_fn, current_enc_feat_map, x_scan)
1162+
1163+
# scanned_out_chunks shape: (num_chunks, B, 1, H', W', C')
1164+
scanned_out_chunks = jnp.swapaxes(scanned_out_chunks, 0, 1)
1165+
1166+
B_out, _, _, H_out, W_out, C_out = scanned_out_chunks.shape
1167+
scanned_out_chunks = jnp.reshape(scanned_out_chunks, (B_out, num_chunks, H_out, W_out, C_out))
1168+
1169+
out_seq = jnp.concatenate([out_seq, scanned_out_chunks], axis=1)
11651170

1166-
out = jnp.concatenate([out, scanned_out_chunks], axis=1)
1171+
return out_seq, current_enc_feat_map
1172+
1173+
out, enc_feat_map = encode_sequence(self.encoder, x, enc_feat_map)
11671174

11681175
# 4. Write the updated cache back to the wrapper object.
11691176
feat_cache._enc_feat_map = enc_feat_map
@@ -1193,57 +1200,64 @@ def _decode(
11931200

11941201
dec_feat_map = feat_cache._feat_map
11951202

1196-
# 1. Evaluate the first frame manually (Cache: None -> RepSentinel/ShapedArray)
1197-
# The decoder returns 1 frame on the first step.
1198-
out_0, dec_feat_map, _ = self.decoder(x[:, 0:1, :, :, :], feat_cache=dec_feat_map, feat_idx=0)
1199-
out = out_0
1200-
1201-
# 2. Evaluate the second frame manually (Cache: RepSentinel -> ShapedArray)
1202-
# This ensures that ALL cache components are ShapedArrays before entering jax.lax.scan,
1203-
# preventing TraceContext errors due to type mismatches.
1204-
if iter_ > 1:
1205-
out_1, dec_feat_map, _ = self.decoder(x[:, 1:2, :, :, :], feat_cache=dec_feat_map, feat_idx=0)
1203+
@nnx.jit
1204+
def decode_sequence(decoder, x_seq, current_dec_feat_map):
1205+
iter_s = x_seq.shape[1]
12061206

1207-
# Bypass an issue where frame[1] should be frame[2] and vice versa.
1208-
fm1 = out_1[:, 0:1, ...]
1209-
fm2 = out_1[:, 1:2, ...]
1210-
fm3 = out_1[:, 2:3, ...]
1211-
fm4 = out_1[:, 3:4, ...]
1212-
out_1_fixed = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
1213-
out = jnp.concatenate([out_0, out_1_fixed], axis=1)
1214-
1215-
# 3. Process remaining frames using jax.lax.scan (requires homogeneous output and carry shapes)
1216-
if iter_ > 2:
1217-
x_rest = x[:, 2:, :, :, :]
1218-
x_scan = jnp.swapaxes(x_rest, 0, 1) # (T-2, B, H, W, C)
1219-
1220-
def scan_fn(carry_cache, input_frame):
1221-
input_frame = jnp.expand_dims(input_frame, 1) # (B, 1, H, W, C)
1222-
out_frames, new_cache, _ = self.decoder(input_frame, feat_cache=carry_cache, feat_idx=0)
1223-
1224-
# Bypass an issue where frame[1] should be frame[2] and vice versa.
1225-
fm1 = out_frames[:, 0:1, ...]
1226-
fm2 = out_frames[:, 1:2, ...]
1227-
fm3 = out_frames[:, 2:3, ...]
1228-
fm4 = out_frames[:, 3:4, ...]
1207+
# 1. Evaluate the first frame manually (Cache: None -> RepSentinel/ShapedArray)
1208+
# The decoder returns 1 frame on the first step.
1209+
out_0, current_dec_feat_map, _ = decoder(x_seq[:, 0:1, :, :, :], feat_cache=current_dec_feat_map, feat_idx=0)
1210+
out_seq = out_0
1211+
1212+
# 2. Evaluate the second frame manually (Cache: RepSentinel -> ShapedArray)
1213+
# This ensures that ALL cache components are ShapedArrays before entering jax.lax.scan,
1214+
# preventing TraceContext errors due to type mismatches.
1215+
if iter_s > 1:
1216+
out_1, current_dec_feat_map, _ = decoder(x_seq[:, 1:2, :, :, :], feat_cache=current_dec_feat_map, feat_idx=0)
12291217

1230-
fixed_out_frames = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
1231-
return new_cache, fixed_out_frames
1232-
1233-
dec_feat_map, scanned_out_frames = jax.lax.scan(scan_fn, dec_feat_map, x_scan)
1234-
1235-
# scanned_out_frames is (T-2, B, 4, H, W, C)
1236-
B = scanned_out_frames.shape[1]
1237-
T_minus_2 = scanned_out_frames.shape[0]
1238-
H, W, C = scanned_out_frames.shape[3], scanned_out_frames.shape[4], scanned_out_frames.shape[5]
1239-
1240-
# Swap back to (B, T-2, 4, H, W, C)
1241-
scanned_out_frames = jnp.swapaxes(scanned_out_frames, 0, 1)
1242-
# Flatten the temporal axes to (B, (T-2)*4, H, W, C)
1243-
scanned_out_frames = jnp.reshape(scanned_out_frames, (B, T_minus_2 * 4, H, W, C))
1244-
1245-
out = jnp.concatenate([out, scanned_out_frames], axis=1)
1246-
1218+
# Bypass an issue where frame[1] should be frame[2] and vice versa.
1219+
fm1 = out_1[:, 0:1, ...]
1220+
fm2 = out_1[:, 1:2, ...]
1221+
fm3 = out_1[:, 2:3, ...]
1222+
fm4 = out_1[:, 3:4, ...]
1223+
out_1_fixed = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
1224+
out_seq = jnp.concatenate([out_0, out_1_fixed], axis=1)
1225+
1226+
# 3. Process remaining frames using jax.lax.scan (requires homogeneous output and carry shapes)
1227+
if iter_s > 2:
1228+
x_rest = x_seq[:, 2:, :, :, :]
1229+
x_scan = jnp.swapaxes(x_rest, 0, 1) # (T-2, B, H, W, C)
1230+
1231+
def scan_fn(carry_cache, input_frame):
1232+
input_frame = jnp.expand_dims(input_frame, 1) # (B, 1, H, W, C)
1233+
out_frames, new_cache, _ = decoder(input_frame, feat_cache=carry_cache, feat_idx=0)
1234+
1235+
# Bypass an issue where frame[1] should be frame[2] and vice versa.
1236+
fm1 = out_frames[:, 0:1, ...]
1237+
fm2 = out_frames[:, 1:2, ...]
1238+
fm3 = out_frames[:, 2:3, ...]
1239+
fm4 = out_frames[:, 3:4, ...]
1240+
1241+
fixed_out_frames = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
1242+
return new_cache, fixed_out_frames
1243+
1244+
current_dec_feat_map, scanned_out_frames = jax.lax.scan(scan_fn, current_dec_feat_map, x_scan)
1245+
1246+
# scanned_out_frames is (T-2, B, 4, H, W, C)
1247+
B = scanned_out_frames.shape[1]
1248+
T_minus_2 = scanned_out_frames.shape[0]
1249+
H, W, C = scanned_out_frames.shape[3], scanned_out_frames.shape[4], scanned_out_frames.shape[5]
1250+
1251+
# Swap back to (B, T-2, 4, H, W, C)
1252+
scanned_out_frames = jnp.swapaxes(scanned_out_frames, 0, 1)
1253+
# Flatten the temporal axes to (B, (T-2)*4, H, W, C)
1254+
scanned_out_frames = jnp.reshape(scanned_out_frames, (B, T_minus_2 * 4, H, W, C))
1255+
1256+
out_seq = jnp.concatenate([out_seq, scanned_out_frames], axis=1)
1257+
1258+
return out_seq, current_dec_feat_map
1259+
1260+
out, dec_feat_map = decode_sequence(self.decoder, x, dec_feat_map)
12471261
feat_cache._feat_map = dec_feat_map
12481262

12491263
out = jnp.clip(out, min=-1.0, max=1.0)

0 commit comments

Comments
 (0)