Skip to content

Commit efe8528

Browse files
Add feature-cache support (feat_cache/feat_idx) to conv1/conv2 calls and resnet blocks in the WAN autoencoder.
1 parent a5e1e95 commit efe8528

1 file changed

Lines changed: 24 additions & 3 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,33 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
367367

368368
x = self.norm1(x)
369369
x = self.nonlinearity(x)
370-
x = self.conv1(x)
370+
371+
if feat_cache is not None:
372+
idx = feat_idx[0]
373+
cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
374+
if cache_x.shape[1] <2 and feat_cache[idx] is not None:
375+
cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
376+
377+
x = self.conv1(x, feat_cache[idx])
378+
feat_cache[idx] = cache_x
379+
feat_idx[0] +=1
380+
else:
381+
x = self.conv1(x)
371382

372383
x = self.norm2(x)
373384
x = self.nonlinearity(x)
374385
x = self.dropout(x)
375-
x = self.conv2(x)
386+
387+
if feat_cache is not None:
388+
idx = feat_idx[0]
389+
cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
390+
if cache_x.shape[1] <2 and feat_cache[idx] is not None:
391+
cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
392+
x = self.conv2(x, feat_cache[idx])
393+
feat_cache[idx] = cache_x
394+
feat_idx[0] +=1
395+
else:
396+
x = self.conv2(x)
376397

377398
return x + h
378399

@@ -442,7 +463,7 @@ def __init__(
442463
self.resnets = resnets
443464

444465
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
445-
x = self.resnets[0](x)
466+
x = self.resnets[0](x, feat_cache, feat_idx)
446467
for attn, resnet in zip(self.attentions, self.resnets[1:]):
447468
if attn is not None:
448469
x = attn(x)

0 commit comments

Comments (0)