Skip to content

Commit 716bf6c

Browse files
committed
Refactor
1 parent 295ada0 commit 716bf6c

1 file changed

Lines changed: 25 additions & 1 deletion

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,30 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
940940
return x, feat_cache, feat_idx
941941

942942

943+
class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution):
944+
pass
945+
946+
def _wan_diag_gauss_dist_flatten(dist):
947+
return (dist.mean, dist.logvar, dist.std, dist.var), (dist.deterministic,)
948+
949+
def _wan_diag_gauss_dist_unflatten(aux, children):
950+
mean, logvar, std, var = children
951+
deterministic = aux[0]
952+
obj = WanDiagonalGaussianDistribution.__new__(WanDiagonalGaussianDistribution)
953+
obj.mean = mean
954+
obj.logvar = logvar
955+
obj.std = std
956+
obj.var = var
957+
obj.deterministic = deterministic
958+
return obj
959+
960+
tree_util.register_pytree_node(
961+
WanDiagonalGaussianDistribution,
962+
_wan_diag_gauss_dist_flatten,
963+
_wan_diag_gauss_dist_unflatten,
964+
)
965+
966+
943967
class AutoencoderKLWanCache:
944968

945969
def __init__(self, module):
@@ -1133,7 +1157,7 @@ def encode(
11331157
) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
11341158
"""Encode video into latent distribution."""
11351159
h = self._encode(x, feat_cache)
1136-
posterior = FlaxDiagonalGaussianDistribution(h)
1160+
posterior = WanDiagonalGaussianDistribution(h)
11371161
if not return_dict:
11381162
return (posterior,)
11391163
return FlaxAutoencoderKLOutput(latent_dist=posterior)

0 commit comments

Comments
 (0)