@@ -940,6 +940,30 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
940940 return x , feat_cache , feat_idx
941941
942942
943+ class WanDiagonalGaussianDistribution (FlaxDiagonalGaussianDistribution ):
944+ pass
945+
946+ def _wan_diag_gauss_dist_flatten (dist ):
947+ return (dist .mean , dist .logvar , dist .std , dist .var ), (dist .deterministic ,)
948+
949+ def _wan_diag_gauss_dist_unflatten (aux , children ):
950+ mean , logvar , std , var = children
951+ deterministic = aux [0 ]
952+ obj = WanDiagonalGaussianDistribution .__new__ (WanDiagonalGaussianDistribution )
953+ obj .mean = mean
954+ obj .logvar = logvar
955+ obj .std = std
956+ obj .var = var
957+ obj .deterministic = deterministic
958+ return obj
959+
960+ tree_util .register_pytree_node (
961+ WanDiagonalGaussianDistribution ,
962+ _wan_diag_gauss_dist_flatten ,
963+ _wan_diag_gauss_dist_unflatten ,
964+ )
965+
966+
943967class AutoencoderKLWanCache :
944968
945969 def __init__ (self , module ):
@@ -1133,7 +1157,7 @@ def encode(
11331157 ) -> Union [FlaxAutoencoderKLOutput , Tuple [FlaxDiagonalGaussianDistribution ]]:
11341158 """Encode video into latent distribution."""
11351159 h = self ._encode (x , feat_cache )
1136- posterior = FlaxDiagonalGaussianDistribution (h )
1160+ posterior = WanDiagonalGaussianDistribution (h )
11371161 if not return_dict :
11381162 return (posterior ,)
11391163 return FlaxAutoencoderKLOutput (latent_dist = posterior )
0 commit comments