Skip to content

Commit 4c28dc2

Browse files
committed
full refactor
1 parent 41a496a commit 4c28dc2

1 file changed

Lines changed: 114 additions & 83 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 114 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -485,15 +485,15 @@ def __init__(
485485
self.conv_shortcut = Identity()
486486
if in_dim != out_dim:
487487
self.conv_shortcut = WanCausalConv3d(
488-
rngs=rngs,
489-
in_channels=in_dim,
490-
out_channels=out_dim,
491-
kernel_size=1,
492-
mesh=mesh,
493-
dtype=dtype,
494-
weights_dtype=weights_dtype,
495-
precision=precision,
496-
)
488+
rngs=rngs,
489+
in_channels=in_dim,
490+
out_channels=out_dim,
491+
kernel_size=1,
492+
mesh=mesh,
493+
dtype=dtype,
494+
weights_dtype=weights_dtype,
495+
precision=precision,
496+
)
497497

498498
def initialize_cache(self, batch_size, height, width, dtype):
499499
"""Initialize cache for all convolutions."""
@@ -572,42 +572,42 @@ def __init__(
572572
)
573573

574574
def __call__(self, x: jax.Array):
575-
identity = x
576-
batch_size, time, height, width, channels = x.shape
577-
575+
identity = x
576+
batch_size, time, height, width, channels = x.shape
577+
578578
# Reshape to process all frames together
579-
x = x.reshape(batch_size * time, height, width, channels)
580-
x = self.norm(x)
581-
582-
qkv = self.to_qkv(x) # (B*T, H, W, C*3)
583-
584-
# Get actual shape after to_qkv to avoid using stale variables
585-
bt, h, w, c3 = qkv.shape
586-
587-
# Flatten spatial dimensions for attention
588-
qkv = qkv.reshape(bt, h * w, c3) # (B*T, H*W, C*3)
589-
qkv = jnp.transpose(qkv, (0, 2, 1)) # (B*T, C*3, H*W)
590-
591-
q, k, v = jnp.split(qkv, 3, axis=1) # Each: (B*T, C, H*W)
592-
q = jnp.transpose(q, (0, 2, 1)) # (B*T, H*W, C)
593-
k = jnp.transpose(k, (0, 2, 1)) # (B*T, H*W, C)
594-
v = jnp.transpose(v, (0, 2, 1)) # (B*T, H*W, C)
595-
596-
# Add head dimension for dot_product_attention
597-
q = jnp.expand_dims(q, 1) # (B*T, 1, H*W, C)
598-
k = jnp.expand_dims(k, 1) # (B*T, 1, H*W, C)
599-
v = jnp.expand_dims(v, 1) # (B*T, 1, H*W, C)
600-
601-
x = jax.nn.dot_product_attention(q, k, v) # (B*T, 1, H*W, C)
602-
x = jnp.squeeze(x, 1) # (B*T, H*W, C)
603-
604-
# Reshape back to spatial dimensions
605-
x = x.reshape(bt, h, w, channels)
606-
x = self.proj(x)
607-
579+
x = x.reshape(batch_size * time, height, width, channels)
580+
x = self.norm(x)
581+
582+
qkv = self.to_qkv(x) # (B*T, H, W, C*3)
583+
584+
# Get actual shape after to_qkv to avoid using stale variables
585+
bt, h, w, c3 = qkv.shape
586+
587+
# Flatten spatial dimensions for attention
588+
qkv = qkv.reshape(bt, h * w, c3) # (B*T, H*W, C*3)
589+
qkv = jnp.transpose(qkv, (0, 2, 1)) # (B*T, C*3, H*W)
590+
591+
q, k, v = jnp.split(qkv, 3, axis=1) # Each: (B*T, C, H*W)
592+
q = jnp.transpose(q, (0, 2, 1)) # (B*T, H*W, C)
593+
k = jnp.transpose(k, (0, 2, 1)) # (B*T, H*W, C)
594+
v = jnp.transpose(v, (0, 2, 1)) # (B*T, H*W, C)
595+
596+
# Add head dimension for dot_product_attention
597+
q = jnp.expand_dims(q, 1) # (B*T, 1, H*W, C)
598+
k = jnp.expand_dims(k, 1) # (B*T, 1, H*W, C)
599+
v = jnp.expand_dims(v, 1) # (B*T, 1, H*W, C)
600+
601+
x = jax.nn.dot_product_attention(q, k, v) # (B*T, 1, H*W, C)
602+
x = jnp.squeeze(x, 1) # (B*T, H*W, C)
603+
604+
# Reshape back to spatial dimensions
605+
x = x.reshape(bt, h, w, channels)
606+
x = self.proj(x)
607+
608608
# Reshape back to original shape
609-
x = x.reshape(batch_size, time, height, width, channels)
610-
return x + identity
609+
x = x.reshape(batch_size, time, height, width, channels)
610+
return x + identity
611611

612612

613613
class WanMidBlock(nnx.Module):
@@ -626,18 +626,18 @@ def __init__(
626626
self.dim = dim
627627
self.resnets = nnx.List(
628628
[
629-
WanResidualBlock(
630-
in_dim=dim,
631-
out_dim=dim,
632-
rngs=rngs,
633-
dropout=dropout,
634-
non_linearity=non_linearity,
635-
mesh=mesh,
636-
dtype=dtype,
637-
weights_dtype=weights_dtype,
638-
precision=precision,
639-
)
640-
]
629+
WanResidualBlock(
630+
in_dim=dim,
631+
out_dim=dim,
632+
rngs=rngs,
633+
dropout=dropout,
634+
non_linearity=non_linearity,
635+
mesh=mesh,
636+
dtype=dtype,
637+
weights_dtype=weights_dtype,
638+
precision=precision,
639+
)
640+
]
641641
)
642642
self.attentions = nnx.List([])
643643
for _ in range(num_layers):
@@ -991,18 +991,18 @@ def __init__(
991991
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
992992
self.up_blocks.append(
993993
WanUpBlock(
994-
in_dim=in_dim,
995-
out_dim=out_dim,
996-
num_res_blocks=num_res_blocks,
997-
dropout=dropout,
998-
upsample_mode=upsample_mode,
999-
non_linearity=non_linearity,
1000-
rngs=rngs,
1001-
mesh=mesh,
1002-
dtype=dtype,
1003-
weights_dtype=weights_dtype,
1004-
precision=precision,
1005-
)
994+
in_dim=in_dim,
995+
out_dim=out_dim,
996+
num_res_blocks=num_res_blocks,
997+
dropout=dropout,
998+
upsample_mode=upsample_mode,
999+
non_linearity=non_linearity,
1000+
rngs=rngs,
1001+
mesh=mesh,
1002+
dtype=dtype,
1003+
weights_dtype=weights_dtype,
1004+
precision=precision,
1005+
)
10061006
)
10071007

10081008
self.norm_out = WanRMS_norm(
@@ -1176,22 +1176,44 @@ def encode(
11761176
if x.shape[-1] != 3:
11771177
x = jnp.transpose(x, (0, 2, 3, 4, 1))
11781178

1179-
x_scan = jnp.swapaxes(x, 0, 1) # (B, T, H, W, C) -> (T, B, H, W, C)
1179+
# Calculate temporal downsampling factor
1180+
temporal_downsample_factor = 1
1181+
for ds in self.temperal_downsample:
1182+
if ds:
1183+
temporal_downsample_factor *= 2
1184+
11801185
b, t, h, w, c = x.shape
1186+
1187+
# Process frames in chunks that match temporal downsampling
1188+
# This prevents frames from being downsampled to 0
1189+
chunk_size = temporal_downsample_factor
1190+
1191+
# Pad time dimension if needed to make it divisible by chunk_size
1192+
if t % chunk_size != 0:
1193+
pad_frames = chunk_size - (t % chunk_size)
1194+
x = jnp.pad(x, ((0, 0), (0, pad_frames), (0, 0), (0, 0), (0, 0)), mode='edge')
1195+
t = x.shape[1]
1196+
1197+
# Reshape to process chunks: (B, T, H, W, C) -> (T//chunk_size, B, chunk_size, H, W, C)
1198+
x_chunks = x.reshape(b, t // chunk_size, chunk_size, h, w, c)
1199+
x_scan = jnp.swapaxes(x_chunks, 0, 1) # -> (T//chunk_size, B, chunk_size, H, W, C)
1200+
11811201
init_cache = self.encoder.init_cache(b, h, w, x.dtype)
11821202

1183-
def scan_fn(carry, input_slice):
1184-
"""Scan function processes one frame at a time."""
1185-
# Expand time dimension for Conv3d compatibility
1186-
input_slice = jnp.expand_dims(input_slice, 1) # (B, H, W, C) -> (B, 1, H, W, C)
1187-
out_slice, new_carry = self.encoder(input_slice, carry)
1188-
# Squeeze time dimension for scan stacking
1189-
out_slice = jnp.squeeze(out_slice, 1) # (B, 1, H', W', C') -> (B, H', W', C')
1190-
return new_carry, out_slice
1203+
def scan_fn(carry, input_chunk):
1204+
"""Scan function processes one chunk of frames at a time."""
1205+
# input_chunk shape: (B, chunk_size, H, W, C)
1206+
out_chunk, new_carry = self.encoder(input_chunk, carry)
1207+
return new_carry, out_chunk
11911208

11921209
# Use jax.lax.scan for JIT-compilable temporal iteration
1193-
final_cache, encoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
1194-
encoded = jnp.swapaxes(encoded_frames, 0, 1) # (T, B, H', W', C') -> (B, T, H', W', C')
1210+
final_cache, encoded_chunks = jax.lax.scan(scan_fn, init_cache, x_scan)
1211+
# encoded_chunks shape: (T//chunk_size, B, T_out_per_chunk, H', W', C')
1212+
1213+
# Reshape back: (T//chunk_size, B, T_out, H', W', C') -> (B, T_total, H', W', C')
1214+
n_chunks, batch, t_per_chunk, h_out, w_out, c_out = encoded_chunks.shape
1215+
encoded = jnp.transpose(encoded_chunks, (1, 0, 2, 3, 4, 5)) # (B, n_chunks, T_out, H', W', C')
1216+
encoded = encoded.reshape(batch, n_chunks * t_per_chunk, h_out, w_out, c_out)
11951217

11961218
# Apply quantization convolution
11971219
enc, _ = self.quant_conv(encoded)
@@ -1221,9 +1243,18 @@ def decode(
12211243

12221244
# Apply post-quantization convolution
12231245
x, _ = self.post_quant_conv(z)
1224-
x_scan = jnp.swapaxes(x, 0, 1) # (B, T, H, W, C) -> (T, B, H, W, C)
1225-
1246+
1247+
# Calculate temporal upsampling factor
1248+
temporal_upsample_factor = 1
1249+
for us in self.temporal_upsample:
1250+
if us:
1251+
temporal_upsample_factor *= 2
1252+
12261253
b, t, h, w, c = x.shape
1254+
1255+
# For decoder, we still process one frame at a time but output will be upsampled
1256+
x_scan = jnp.swapaxes(x, 0, 1) # (B, T, H, W, C) -> (T, B, H, W, C)
1257+
12271258
init_cache = self.decoder.init_cache(b, h, w, x.dtype)
12281259

12291260
def scan_fn(carry, input_slice):
@@ -1238,11 +1269,11 @@ def scan_fn(carry, input_slice):
12381269
# Use jax.lax.scan for JIT-compilable temporal iteration
12391270
final_cache, decoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
12401271

1241-
# decoded_frames shape: (T_lat, B, 4, H, W, C)
1242-
# Transpose to (B, T_lat, 4, H, W, C)
1272+
# decoded_frames shape: (T_lat, B, T_upsample, H, W, C)
1273+
# Transpose to (B, T_lat, T_upsample, H, W, C)
12431274
decoded = jnp.transpose(decoded_frames, (1, 0, 2, 3, 4, 5))
12441275

1245-
# Reshape to (B, T_lat*4, H, W, C)
1276+
# Reshape to (B, T_lat * T_upsample, H, W, C)
12461277
b, t_lat, t_sub, h, w, c = decoded.shape
12471278
decoded = decoded.reshape(b, t_lat * t_sub, h, w, c)
12481279

0 commit comments

Comments
 (0)