Skip to content

Commit 8e313fc

Browse files
committed
Cache added
1 parent 10dcf3f commit 8e313fc

1 file changed

Lines changed: 24 additions & 0 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,30 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
10261026

10271027
return x, new_cache
10281028

1029+
class AutoencoderKLWanCache:
1030+
1031+
def __init__(self, module):
1032+
self.module = module
1033+
self.clear_cache()
1034+
1035+
def clear_cache(self):
1036+
"""Resets cache dictionaries and indices"""
1037+
1038+
def _count_conv3d(module):
1039+
count = 0
1040+
node_types = nnx.graph.iter_graph([module])
1041+
for _, value in node_types:
1042+
if isinstance(value, WanCausalConv3d):
1043+
count += 1
1044+
return count
1045+
1046+
self._conv_num = _count_conv3d(self.module.decoder)
1047+
self._conv_idx = [0]
1048+
self._feat_map = [None] * self._conv_num
1049+
# cache encode
1050+
self._enc_conv_num = _count_conv3d(self.module.encoder)
1051+
self._enc_conv_idx = [0]
1052+
self._enc_feat_map = [None] * self._enc_conv_num
10291053

10301054
class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin):
10311055
def __init__(

0 commit comments

Comments
 (0)