
Commit 1108fdd

committed
full refactor
1 parent 983841d commit 1108fdd

1 file changed: 56 additions & 124 deletions

File changed: src/maxdiffusion/models/wan/autoencoder_kl_wan.py

@@ -672,101 +672,6 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
     return x, new_cache
 
 
-class WanDownBlock(nnx.Module):
-  def __init__(
-      self,
-      in_dim: int,
-      out_dim: int,
-      num_res_blocks: int,
-      rngs: nnx.Rngs,
-      dropout: float = 0.0,
-      downsample_mode: Optional[str] = None,
-      add_attention: bool = False,
-      non_linearity: str = "silu",
-      mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
-      precision: jax.lax.Precision = None,
-  ):
-    self.layers = nnx.List([])
-    current_dim = in_dim
-    for _ in range(num_res_blocks):
-      self.layers.append(
-          WanResidualBlock(
-              in_dim=current_dim,
-              out_dim=out_dim,
-              dropout=dropout,
-              non_linearity=non_linearity,
-              rngs=rngs,
-              mesh=mesh,
-              dtype=dtype,
-              weights_dtype=weights_dtype,
-              precision=precision,
-          )
-      )
-      if add_attention:
-        self.layers.append(
-            WanAttentionBlock(
-                dim=out_dim,
-                rngs=rngs,
-                mesh=mesh,
-                dtype=dtype,
-                weights_dtype=weights_dtype,
-                precision=precision,
-            )
-        )
-      current_dim = out_dim
-
-    if downsample_mode is not None:
-      self.layers.append(
-          WanResample(
-              out_dim,
-              mode=downsample_mode,
-              rngs=rngs,
-              mesh=mesh,
-              dtype=dtype,
-              weights_dtype=weights_dtype,
-              precision=precision,
-          )
-      )
-
-  def initialize_cache(self, batch_size, height, width, dtype):
-    """Initialize cache for all layers."""
-    cache = {"layers": []}
-    h_curr, w_curr = height, width
-    for layer in self.layers:
-      if isinstance(layer, WanResidualBlock):
-        cache["layers"].append(
-            layer.initialize_cache(batch_size, h_curr, w_curr, dtype)
-        )
-      elif isinstance(layer, WanResample):
-        cache["layers"].append(
-            layer.initialize_cache(batch_size, h_curr, w_curr, dtype)
-        )
-        if layer.mode in ["downsample2d", "downsample3d"]:
-          h_curr, w_curr = h_curr // 2, w_curr // 2
-      else:  # Attention
-        cache["layers"].append(None)
-    return cache
-
-  def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
-    """Pure function: returns (output, new_cache)."""
-    if cache is None:
-      cache = {}
-    new_cache = {"layers": []}
-
-    current_caches = cache.get("layers", [None] * len(self.layers))
-    for i, layer in enumerate(self.layers):
-      if isinstance(layer, (WanResidualBlock, WanResample)):
-        x, c = layer(x, current_caches[i])
-        new_cache["layers"].append(c)
-      else:  # Attention
-        x = layer(x)
-        new_cache["layers"].append(None)
-
-    return x, new_cache
-
-
 class WanUpBlock(nnx.Module):
   def __init__(
       self,
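
One practical effect of dropping WanDownBlock is on the shape of the cache pytree: before, each entry of cache["down_blocks"] was a nested dict wrapping that block's child layers; after, each entry is a single per-layer slot. A minimal sketch of the two layouts follows; the slot contents are placeholders, not the real cache entries.

# Sketch only: placeholder slots standing in for whatever the real per-layer caches hold.
res_cache = {"prev_frames": None}
resample_cache = {"prev_frames": None}

# Before this commit: one nested entry per WanDownBlock, each wrapping its child layers.
cache_before = {
    "down_blocks": [
        {"layers": [res_cache, res_cache, resample_cache]},  # first down stage
        {"layers": [res_cache, res_cache, resample_cache]},  # second down stage
    ],
}

# After: the encoder keeps one flat list with a slot per layer
# (None for attention layers, which carry no cache across calls).
cache_after = {
    "down_blocks": [
        res_cache, res_cache, resample_cache,
        res_cache, res_cache, resample_cache,
    ],
}
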
@@ -883,27 +788,45 @@ def __init__(
 
     self.down_blocks = nnx.List([])
     for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
-      add_attention = scale in attn_scales
-      downsample_mode = None
+      for _ in range(num_res_blocks):
+        self.down_blocks.append(
+            WanResidualBlock(
+                in_dim=in_dim,
+                out_dim=out_dim,
+                dropout=dropout,
+                rngs=rngs,
+                mesh=mesh,
+                dtype=dtype,
+                weights_dtype=weights_dtype,
+                precision=precision,
+            )
+        )
+        if scale in attn_scales:
+          self.down_blocks.append(
+              WanAttentionBlock(
+                  dim=out_dim,
+                  rngs=rngs,
+                  mesh=mesh,
+                  dtype=dtype,
+                  weights_dtype=weights_dtype,
+                  precision=precision,
+              )
+          )
+        in_dim = out_dim
       if i != len(dim_mult) - 1:
-        downsample_mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
-        scale /= 2.0
-
-      self.down_blocks.append(
-          WanDownBlock(
-              in_dim=in_dim,
-              out_dim=out_dim,
-              num_res_blocks=num_res_blocks,
-              dropout=dropout,
-              downsample_mode=downsample_mode,
-              add_attention=add_attention,
-              rngs=rngs,
-              mesh=mesh,
-              dtype=dtype,
-              weights_dtype=weights_dtype,
-              precision=precision,
+        mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+        self.down_blocks.append(
+            WanResample(
+                out_dim,
+                mode=mode,
+                rngs=rngs,
+                mesh=mesh,
+                dtype=dtype,
+                weights_dtype=weights_dtype,
+                precision=precision,
+            )
         )
-      )
+        scale /= 2.0
 
     self.mid_block = WanMidBlock(
         dim=out_dim,
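
To see what the flattened self.down_blocks list ends up containing, here is a small sketch that mirrors the new construction loop with strings in place of nnx modules. The config values (dim=96, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[False, True, True]) and the dims derivation are illustrative assumptions, not values read from this file.

# Sketch only: mirrors the new down_blocks construction with strings instead of modules.
dim = 96
dim_mult = [1, 2, 4, 4]
num_res_blocks = 2
attn_scales = []                         # attention is only inserted at these scales
temperal_downsample = [False, True, True]

dims = [dim * m for m in [1] + dim_mult]  # assumed derivation, e.g. [96, 96, 192, 384, 384]
scale = 1.0
down_blocks = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
  for _ in range(num_res_blocks):
    down_blocks.append(f"ResidualBlock({in_dim}->{out_dim})")
    if scale in attn_scales:
      down_blocks.append(f"AttentionBlock({out_dim})")
    in_dim = out_dim
  if i != len(dim_mult) - 1:
    mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
    down_blocks.append(f"Resample({out_dim}, {mode})")
    scale /= 2.0

print(down_blocks)
# ['ResidualBlock(96->96)', 'ResidualBlock(96->96)', 'Resample(96, downsample2d)',
#  'ResidualBlock(96->192)', 'ResidualBlock(192->192)', 'Resample(192, downsample3d)', ...]

The flat ordering matters because init_cache and __call__ below walk this same list and dispatch on layer type.
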
@@ -940,14 +863,19 @@ def init_cache(self, batch_size, height, width, dtype):
     cache["down_blocks"] = []
 
     h_curr, w_curr = height, width
-    for block in self.down_blocks:
-      cache["down_blocks"].append(
-          block.initialize_cache(batch_size, h_curr, w_curr, dtype)
-      )
-      # Update dimensions if downsampling
-      if block.layers and isinstance(block.layers[-1], WanResample):
-        if block.layers[-1].mode in ["downsample2d", "downsample3d"]:
+    for layer in self.down_blocks:
+      if isinstance(layer, WanResidualBlock):
+        cache["down_blocks"].append(
+            layer.initialize_cache(batch_size, h_curr, w_curr, dtype)
+        )
+      elif isinstance(layer, WanResample):
+        cache["down_blocks"].append(
+            layer.initialize_cache(batch_size, h_curr, w_curr, dtype)
+        )
+        if layer.mode in ["downsample2d", "downsample3d"]:
           h_curr, w_curr = h_curr // 2, w_curr // 2
+      else:  # Attention
+        cache["down_blocks"].append(None)
 
     cache["mid_block"] = self.mid_block.initialize_cache(
         batch_size, h_curr, w_curr, dtype
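
init_cache has to track the shrinking spatial size as it walks the flat list, because each layer's cache is allocated at the resolution that layer actually sees. A toy version of that bookkeeping, assuming a hypothetical 256x256 input and the illustrative layer order from the sketch above:

# Sketch only: shows how h_curr / w_curr shrink while walking the flat list.
layers = [
    "residual", "residual", "downsample2d",
    "residual", "residual", "downsample3d",
    "residual", "residual", "downsample3d",
    "residual", "residual",
]

h_curr, w_curr = 256, 256
cache_shapes = []
for layer in layers:
  if layer == "residual":
    cache_shapes.append((h_curr, w_curr))      # allocated at the current resolution
  elif layer.startswith("downsample"):
    cache_shapes.append((h_curr, w_curr))      # resample caches the pre-downsample size
    h_curr, w_curr = h_curr // 2, w_curr // 2  # later layers see half the resolution
  else:                                        # attention: no cache entry
    cache_shapes.append(None)

print(cache_shapes[:4])  # [(256, 256), (256, 256), (256, 256), (128, 128)]
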
@@ -969,9 +897,13 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
     new_cache["down_blocks"] = []
     current_down_caches = cache.get("down_blocks", [None] * len(self.down_blocks))
 
-    for i, block in enumerate(self.down_blocks):
-      x, c = block(x, current_down_caches[i])
-      new_cache["down_blocks"].append(c)
+    for i, layer in enumerate(self.down_blocks):
+      if isinstance(layer, (WanResidualBlock, WanResample)):
+        x, c = layer(x, current_down_caches[i])
+        new_cache["down_blocks"].append(c)
+      else:  # Attention
+        x = layer(x)
+        new_cache["down_blocks"].append(None)
 
     x, c = self.mid_block(x, cache.get("mid_block"))
     new_cache["mid_block"] = c
