
Commit 9e25d57 ("full refactor")
Parent: 5b052a7

1 file changed: src/maxdiffusion/models/wan/autoencoder_kl_wan.py (77 additions, 77 deletions)

Every removed/added pair below carries identical code, so the refactor appears to be indentation-only, consistent with the matching 77/77 addition/deletion count.
```diff
--- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
+++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -391,7 +391,7 @@ def __call__(
     new_cache = {}

     if self.mode == "upsample2d":
-      b, t, h, w, c = x.shape
+      b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
       x = self.resample(x)
       h_new, w_new, c_new = x.shape[1:]
```
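The `upsample2d` branch applies a purely spatial resample by folding the time axis into the batch axis, then reads the new spatial shape off the result before unfolding. A minimal sketch of that pattern, assuming a channels-last `(b, t, h, w, c)` video tensor and using `jax.image.resize` as a hypothetical stand-in for `self.resample`:

```python
import jax
import jax.numpy as jnp

def spatial_upsample_frames(x: jax.Array) -> jax.Array:
  # x: (b, t, h, w, c), channels-last video tensor.
  b, t, h, w, c = x.shape
  # Fold time into batch so a 2D op sees independent frames.
  frames = x.reshape(b * t, h, w, c)
  # Hypothetical stand-in for self.resample: 2x nearest-neighbor upsampling.
  frames = jax.image.resize(frames, (b * t, h * 2, w * 2, c), method="nearest")
  # Restore the (b, t, ...) layout, reading new spatial dims off the result.
  h_new, w_new, c_new = frames.shape[1:]
  return frames.reshape(b, t, h_new, w_new, c_new)

x = jnp.ones((1, 4, 8, 8, 16))
print(spatial_upsample_frames(x).shape)  # (1, 4, 16, 16, 16)
```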
```diff
@@ -403,14 +403,14 @@ def __call__(

       b, t, h, w, c = x.shape
       x = x.reshape(b, t, h, w, 2, c // 2)
-      x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
+      x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
       x = x.reshape(b, t * 2, h, w, c // 2)

       b, t, h, w, c = x.shape
-      x = x.reshape(b * t, h, w, c)
-      x = self.resample(x)
-      h_new, w_new, c_new = x.shape[1:]
-      x = x.reshape(b, t, h_new, w_new, c_new)
+      x = x.reshape(b * t, h, w, c)
+      x = self.resample(x)
+      h_new, w_new, c_new = x.shape[1:]
+      x = x.reshape(b, t, h_new, w_new, c_new)

     elif self.mode == "downsample2d":
       b, t, h, w, c = x.shape
```
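In the `upsample3d` branch above, two output frames have apparently been packed into the channel axis by the preceding time conv; the stack/reshape pair unpacks them into the time axis (doubling `t`, halving `c`) before the same fold-into-batch spatial resample is reused. A standalone sketch of just that unpacking, mirroring the diff's stack/reshape:

```python
import jax.numpy as jnp

def temporal_unpack(x):
  # x: (b, t, h, w, c) with c even; each position carries two frames'
  # worth of features packed along the channel axis.
  b, t, h, w, c = x.shape
  x = x.reshape(b, t, h, w, 2, c // 2)
  # Matching the diff: stack the two halves on a new axis after batch...
  x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)  # (b, 2, t, h, w, c//2)
  # ...then merge that axis with time, doubling the frame count.
  return x.reshape(b, t * 2, h, w, c // 2)

print(temporal_unpack(jnp.ones((1, 4, 8, 8, 16))).shape)  # (1, 8, 8, 8, 8)
```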
```diff
@@ -429,7 +429,7 @@ def __call__(
        x, tc_cache = self.time_conv(x, cache.get("time_conv"))
        new_cache["time_conv"] = tc_cache

-    else:
+    else:
       if hasattr(self, "resample"):
         if isinstance(self.resample, Identity):
           x, _ = self.resample(x, None)
```
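The `cache.get("time_conv")` / `new_cache["time_conv"]` pair above is the streaming convention used throughout this file: a causal conv consumes the state left by the previous chunk and returns the state for the next one, so chunked decoding behaves like one continuous sequence. A hedged sketch of a caller threading that cache (the `block` callable is hypothetical; it stands for any module with the `(x, cache) -> (x, new_cache)` signature):

```python
def decode_chunks(chunks, block):
  """Stream video chunks through `block`, threading its conv cache."""
  cache = {}  # maps submodule names (e.g. "time_conv") to streaming state
  outs = []
  for chunk in chunks:
    # Each call consumes the previous chunk's state and returns the next,
    # so causal time convs see one continuous sequence across boundaries.
    chunk_out, cache = block(chunk, cache)
    outs.append(chunk_out)
  return outs
```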
```diff
@@ -485,15 +485,15 @@ def __init__(
     self.conv_shortcut = Identity()
     if in_dim != out_dim:
       self.conv_shortcut = WanCausalConv3d(
-          rngs=rngs,
-          in_channels=in_dim,
-          out_channels=out_dim,
-          kernel_size=1,
-          mesh=mesh,
-          dtype=dtype,
-          weights_dtype=weights_dtype,
-          precision=precision,
-      )
+          rngs=rngs,
+          in_channels=in_dim,
+          out_channels=out_dim,
+          kernel_size=1,
+          mesh=mesh,
+          dtype=dtype,
+          weights_dtype=weights_dtype,
+          precision=precision,
+      )

   def initialize_cache(self, batch_size, height, width, dtype):
     """Initialize cache for all convolutions."""
```
```diff
@@ -571,43 +571,43 @@ def __init__(
         precision=precision,
     )

-  def __call__(self, x: jax.Array):
-    identity = x
-    batch_size, time, height, width, channels = x.shape
-
-    # Reshape to process all frames together
-    x = x.reshape(batch_size * time, height, width, channels)
-    x = self.norm(x)
-
-    qkv = self.to_qkv(x)  # (B*T, H, W, C*3)
-
-    # Get actual shape after to_qkv to avoid using stale variables
-    bt, h, w, c3 = qkv.shape
-
-    # Flatten spatial dimensions for attention
-    qkv = qkv.reshape(bt, h * w, c3)  # (B*T, H*W, C*3)
-    qkv = jnp.transpose(qkv, (0, 2, 1))  # (B*T, C*3, H*W)
-
-    q, k, v = jnp.split(qkv, 3, axis=1)  # Each: (B*T, C, H*W)
-    q = jnp.transpose(q, (0, 2, 1))  # (B*T, H*W, C)
-    k = jnp.transpose(k, (0, 2, 1))  # (B*T, H*W, C)
-    v = jnp.transpose(v, (0, 2, 1))  # (B*T, H*W, C)
-
-    # Add head dimension for dot_product_attention
-    q = jnp.expand_dims(q, 1)  # (B*T, 1, H*W, C)
-    k = jnp.expand_dims(k, 1)  # (B*T, 1, H*W, C)
-    v = jnp.expand_dims(v, 1)  # (B*T, 1, H*W, C)
-
-    x = jax.nn.dot_product_attention(q, k, v)  # (B*T, 1, H*W, C)
-    x = jnp.squeeze(x, 1)  # (B*T, H*W, C)
-
-    # Reshape back to spatial dimensions
-    x = x.reshape(bt, h, w, channels)
-    x = self.proj(x)
-
-    # Reshape back to original shape
-    x = x.reshape(batch_size, time, height, width, channels)
-    return x + identity
+  def __call__(self, x: jax.Array):
+    identity = x
+    batch_size, time, height, width, channels = x.shape
+
+    # Reshape to process all frames together
+    x = x.reshape(batch_size * time, height, width, channels)
+    x = self.norm(x)
+
+    qkv = self.to_qkv(x)  # (B*T, H, W, C*3)
+
+    # Get actual shape after to_qkv to avoid using stale variables
+    bt, h, w, c3 = qkv.shape
+
+    # Flatten spatial dimensions for attention
+    qkv = qkv.reshape(bt, h * w, c3)  # (B*T, H*W, C*3)
+    qkv = jnp.transpose(qkv, (0, 2, 1))  # (B*T, C*3, H*W)
+
+    q, k, v = jnp.split(qkv, 3, axis=1)  # Each: (B*T, C, H*W)
+    q = jnp.transpose(q, (0, 2, 1))  # (B*T, H*W, C)
+    k = jnp.transpose(k, (0, 2, 1))  # (B*T, H*W, C)
+    v = jnp.transpose(v, (0, 2, 1))  # (B*T, H*W, C)
+
+    # Add head dimension for dot_product_attention
+    q = jnp.expand_dims(q, 1)  # (B*T, 1, H*W, C)
+    k = jnp.expand_dims(k, 1)  # (B*T, 1, H*W, C)
+    v = jnp.expand_dims(v, 1)  # (B*T, 1, H*W, C)
+
+    x = jax.nn.dot_product_attention(q, k, v)  # (B*T, 1, H*W, C)
+    x = jnp.squeeze(x, 1)  # (B*T, H*W, C)
+
+    # Reshape back to spatial dimensions
+    x = x.reshape(bt, h, w, channels)
+    x = self.proj(x)
+
+    # Reshape back to original shape
+    x = x.reshape(batch_size, time, height, width, channels)
+    return x + identity


 class WanMidBlock(nnx.Module):
```
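One point worth noting for the attention block above: `jax.nn.dot_product_attention` expects inputs laid out as `(batch, seq_len, num_heads, head_dim)`, not the `(batch, heads, seq, dim)` layout of torch's `scaled_dot_product_attention`, so in a standalone single-head version of this per-frame spatial attention the head axis goes at position 2. A minimal sketch under that layout, with plain matmul projections standing in for the module's `to_qkv` and `proj` convs:

```python
import jax
import jax.numpy as jnp

def spatial_attention(x, wq, wk, wv):
  # x: (b, t, h, w, c); attend over the h*w positions of each frame.
  b, t, h, w, c = x.shape
  tokens = x.reshape(b * t, h * w, c)  # each frame becomes a token sequence
  q, k, v = tokens @ wq, tokens @ wk, tokens @ wv  # each (b*t, h*w, c)
  # jax.nn.dot_product_attention takes (batch, seq_len, num_heads, head_dim),
  # so the single head axis is inserted at position 2.
  q, k, v = (jnp.expand_dims(a, 2) for a in (q, k, v))  # (b*t, h*w, 1, c)
  out = jax.nn.dot_product_attention(q, k, v)  # (b*t, h*w, 1, c)
  out = jnp.squeeze(out, 2).reshape(b, t, h, w, c)
  return x + out  # residual add, as in the block above

key = jax.random.key(0)
x = jnp.ones((1, 2, 4, 4, 8))
wq = wk = wv = jax.random.normal(key, (8, 8))
print(spatial_attention(x, wq, wk, wv).shape)  # (1, 2, 4, 4, 8)
```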
```diff
@@ -626,18 +626,18 @@ def __init__(
     self.dim = dim
     self.resnets = nnx.List(
         [
-            WanResidualBlock(
-                in_dim=dim,
-                out_dim=dim,
-                rngs=rngs,
-                dropout=dropout,
-                non_linearity=non_linearity,
-                mesh=mesh,
-                dtype=dtype,
-                weights_dtype=weights_dtype,
-                precision=precision,
-            )
-        ]
+            WanResidualBlock(
+                in_dim=dim,
+                out_dim=dim,
+                rngs=rngs,
+                dropout=dropout,
+                non_linearity=non_linearity,
+                mesh=mesh,
+                dtype=dtype,
+                weights_dtype=weights_dtype,
+                precision=precision,
+            )
+        ]
     )
     self.attentions = nnx.List([])
     for _ in range(num_layers):
```
```diff
@@ -991,18 +991,18 @@ def __init__(
       upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
       self.up_blocks.append(
           WanUpBlock(
-              in_dim=in_dim,
-              out_dim=out_dim,
-              num_res_blocks=num_res_blocks,
-              dropout=dropout,
-              upsample_mode=upsample_mode,
-              non_linearity=non_linearity,
-              rngs=rngs,
-              mesh=mesh,
-              dtype=dtype,
-              weights_dtype=weights_dtype,
-              precision=precision,
-          )
+              in_dim=in_dim,
+              out_dim=out_dim,
+              num_res_blocks=num_res_blocks,
+              dropout=dropout,
+              upsample_mode=upsample_mode,
+              non_linearity=non_linearity,
+              rngs=rngs,
+              mesh=mesh,
+              dtype=dtype,
+              weights_dtype=weights_dtype,
+              precision=precision,
+          )
       )

     self.norm_out = WanRMS_norm(
```
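For context on the arguments above: each decoder level picks its resample mode from the `temperal_upsample` flags (the spelling follows the upstream config key), so only flagged levels grow the frame count while every level upsamples spatially. An illustrative sketch of that selection (the per-level dims and flags are made up):

```python
# Hypothetical per-level settings, for illustration only.
temperal_upsample = [True, True, False]
dims = [(384, 192), (192, 96), (96, 96)]  # (in_dim, out_dim) per level

for i, (in_dim, out_dim) in enumerate(dims):
  # 3D upsampling (time and space) only where the flag is set;
  # otherwise the level upsamples spatially and keeps the frame count.
  upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
  print(i, in_dim, out_dim, upsample_mode)
```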
