Skip to content

Commit 7e66fd6

Browse files
committed
reformatted
1 parent 6338698 commit 7e66fd6

1 file changed

Lines changed: 17 additions & 7 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,12 @@
3838
from jax.sharding import PartitionSpec
3939
from jax.lax import with_sharding_constraint
4040

41+
4142
def _update_cache(cache, idx, value):
4243
if cache is None:
4344
return None
44-
return cache[:idx] + (value,) + cache[idx+1:]
45+
return cache[:idx] + (value,) + cache[idx + 1 :]
46+
4547

4648
# Helper to ensure kernel_size, stride, padding are tuples of 3 integers
4749
def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:
@@ -55,11 +57,14 @@ def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> T
5557

5658

5759
class RepSentinel:
60+
5861
def __eq__(self, other):
5962
return isinstance(other, RepSentinel)
6063

64+
6165
tree_util.register_pytree_node(RepSentinel, lambda x: ((), None), lambda _, __: RepSentinel())
6266

67+
6368
class WanCausalConv3d(nnx.Module):
6469

6570
def __init__(
@@ -503,7 +508,6 @@ def __init__(
503508
)
504509

505510
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
506-
507511
identity = x
508512
batch_size, time, height, width, channels = x.shape
509513

@@ -949,9 +953,11 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
949953
class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution):
950954
pass
951955

956+
952957
def _wan_diag_gauss_dist_flatten(dist):
953958
return (dist.mean, dist.logvar, dist.std, dist.var), (dist.deterministic,)
954959

960+
955961
def _wan_diag_gauss_dist_unflatten(aux, children):
956962
mean, logvar, std, var = children
957963
deterministic = aux[0]
@@ -963,6 +969,7 @@ def _wan_diag_gauss_dist_unflatten(aux, children):
963969
obj.deterministic = deterministic
964970
return obj
965971

972+
966973
tree_util.register_pytree_node(
967974
WanDiagonalGaussianDistribution,
968975
_wan_diag_gauss_dist_flatten,
@@ -993,9 +1000,11 @@ def init_cache(self):
9931000
# cache encode
9941001
self._enc_feat_map = (None,) * self._enc_conv_num
9951002

1003+
9961004
def _wan_cache_flatten(cache):
9971005
return (cache._feat_map, cache._enc_feat_map), (cache._conv_num, cache._enc_conv_num)
9981006

1007+
9991008
def _wan_cache_unflatten(aux, children):
10001009
conv_num, enc_conv_num = aux
10011010
feat_map, enc_feat_map = children
@@ -1009,9 +1018,10 @@ def _wan_cache_unflatten(aux, children):
10091018
obj._enc_conv_num = enc_conv_num
10101019
obj._feat_map = feat_map
10111020
obj._enc_feat_map = enc_feat_map
1012-
obj.module = None # module is not needed inside the trace for the cache logic now
1021+
obj.module = None # module is not needed inside the trace for the cache logic now
10131022
return obj
10141023

1024+
10151025
tree_util.register_pytree_node(AutoencoderKLWanCache, _wan_cache_flatten, _wan_cache_unflatten)
10161026

10171027

@@ -1147,10 +1157,10 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
11471157
feat_idx=enc_conv_idx,
11481158
)
11491159
out = jnp.concatenate([out, out_], axis=1)
1150-
1160+
11511161
# Update back to the wrapper object if needed, but for result we use local vars
11521162
feat_cache._enc_feat_map = enc_feat_map
1153-
1163+
11541164
enc = self.quant_conv(out)
11551165
mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
11561166
enc = jnp.concatenate([mu, logvar], axis=-1)
@@ -1173,7 +1183,7 @@ def _decode(
11731183
feat_cache.init_cache()
11741184
iter_ = z.shape[1]
11751185
x = self.post_quant_conv(z)
1176-
1186+
11771187
dec_feat_map = feat_cache._feat_map
11781188

11791189
for i in range(iter_):
@@ -1199,7 +1209,7 @@ def _decode(
11991209
fm3 = jnp.expand_dims(fm3, axis=axis)
12001210
fm4 = jnp.expand_dims(fm4, axis=axis)
12011211
out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
1202-
1212+
12031213
feat_cache._feat_map = dec_feat_map
12041214

12051215
out = jnp.clip(out, min=-1.0, max=1.0)

0 commit comments

Comments
 (0)