Commit 5b052a7 ("full refactor", parent f004d9f)

1 file changed: src/maxdiffusion/models/wan/autoencoder_kl_wan.py
37 additions & 33 deletions

@@ -571,39 +571,43 @@ def __init__(
         precision=precision,
     )
 
-  def __call__(self, x: jax.Array):
-    identity = x
-    batch_size, time, height, width, channels = x.shape
-
-    # Reshape to process all frames together
-    x = x.reshape(batch_size * time, height, width, channels)
-    x = self.norm(x)
-
-    qkv = self.to_qkv(x)  # (B*T, H, W, C*3)
-    # Flatten spatial dimensions for attention
-    qkv = qkv.reshape(batch_size * time, -1, channels * 3)  # (B*T, H*W, C*3)
-    qkv = jnp.transpose(qkv, (0, 2, 1))  # (B*T, C*3, H*W)
-
-    q, k, v = jnp.split(qkv, 3, axis=1)  # Each: (B*T, C, H*W)
-    q = jnp.transpose(q, (0, 2, 1))  # (B*T, H*W, C)
-    k = jnp.transpose(k, (0, 2, 1))  # (B*T, H*W, C)
-    v = jnp.transpose(v, (0, 2, 1))  # (B*T, H*W, C)
-
-    # Add head dimension for dot_product_attention
-    q = jnp.expand_dims(q, 1)  # (B*T, 1, H*W, C)
-    k = jnp.expand_dims(k, 1)  # (B*T, 1, H*W, C)
-    v = jnp.expand_dims(v, 1)  # (B*T, 1, H*W, C)
-
-    x = jax.nn.dot_product_attention(q, k, v)  # (B*T, 1, H*W, C)
-    x = jnp.squeeze(x, 1)  # (B*T, H*W, C)
-
-    # Reshape back to spatial dimensions
-    x = x.reshape(batch_size * time, height, width, channels)
-    x = self.proj(x)
-
-    # Reshape back to original shape
-    x = x.reshape(batch_size, time, height, width, channels)
-    return x + identity
+  def __call__(self, x: jax.Array):
+    identity = x
+    batch_size, time, height, width, channels = x.shape
+
+    # Reshape to process all frames together
+    x = x.reshape(batch_size * time, height, width, channels)
+    x = self.norm(x)
+
+    qkv = self.to_qkv(x)  # (B*T, H, W, C*3)
+
+    # Get actual shape after to_qkv to avoid using stale variables
+    bt, h, w, c3 = qkv.shape
+
+    # Flatten spatial dimensions for attention
+    qkv = qkv.reshape(bt, h * w, c3)  # (B*T, H*W, C*3)
+    qkv = jnp.transpose(qkv, (0, 2, 1))  # (B*T, C*3, H*W)
+
+    q, k, v = jnp.split(qkv, 3, axis=1)  # Each: (B*T, C, H*W)
+    q = jnp.transpose(q, (0, 2, 1))  # (B*T, H*W, C)
+    k = jnp.transpose(k, (0, 2, 1))  # (B*T, H*W, C)
+    v = jnp.transpose(v, (0, 2, 1))  # (B*T, H*W, C)
+
+    # Add head dimension for dot_product_attention
+    q = jnp.expand_dims(q, 1)  # (B*T, 1, H*W, C)
+    k = jnp.expand_dims(k, 1)  # (B*T, 1, H*W, C)
+    v = jnp.expand_dims(v, 1)  # (B*T, 1, H*W, C)
+
+    x = jax.nn.dot_product_attention(q, k, v)  # (B*T, 1, H*W, C)
+    x = jnp.squeeze(x, 1)  # (B*T, H*W, C)
+
+    # Reshape back to spatial dimensions
+    x = x.reshape(bt, h, w, channels)
+    x = self.proj(x)
+
+    # Reshape back to original shape
+    x = x.reshape(batch_size, time, height, width, channels)
+    return x + identity
 
 
 class WanMidBlock(nnx.Module):
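The refactor swaps reshape arguments derived from pre-projection variables (`batch_size * time`, `-1`, `channels * 3`) for dimensions read directly off `qkv.shape`, so the flatten and un-flatten steps track whatever `to_qkv` actually produced. A minimal sketch of that pattern, assuming a hypothetical stand-in `to_qkv` in place of the real maxdiffusion projection module:

```python
import jax.numpy as jnp


def to_qkv(x):
  # Hypothetical stand-in for the module's QKV projection: maps the
  # channel dim C to 3*C while preserving the spatial dims, the way a
  # 1x1 convolution would.
  c = x.shape[-1]
  weight = jnp.ones((c, 3 * c), dtype=x.dtype)
  return x @ weight  # (B*T, H, W, C*3)


x = jnp.ones((2, 4, 4, 8))        # (B*T, H, W, C)
qkv = to_qkv(x)

# Read the sizes from the projection output itself, as the commit does,
# rather than reusing the pre-projection shape variables.
bt, h, w, c3 = qkv.shape
qkv = qkv.reshape(bt, h * w, c3)  # (B*T, H*W, C*3)
print(qkv.shape)                  # (2, 16, 24)
```

With the real module the spatial extents are unchanged and `c3 == 3 * channels`, so numerical behavior is identical; the new code only removes the dependence on variables captured before the projection.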
