|
24 | 24 | from ...configuration_utils import ConfigMixin |
25 | 25 | from ..modeling_flax_utils import FlaxModelMixin, get_activation |
26 | 26 | from ... import common_types |
27 | | -from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) |
| 27 | +from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput, WanDiagonalGaussianDistribution) |
28 | 28 |
|
29 | 29 | BlockSizes = common_types.BlockSizes |
30 | 30 |
|
@@ -645,16 +645,10 @@ def __init__( |
645 | 645 |
|
646 | 646 | def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): |
647 | 647 | for resnet in self.resnets: |
648 | | - if feat_cache is not None: |
649 | | - x, feat_cache, feat_idx = resnet(x, feat_cache, feat_idx) |
650 | | - else: |
651 | | - x, _, _ = resnet(x) |
| 648 | + x, feat_cache, feat_idx = resnet(x, feat_cache, feat_idx) |
652 | 649 |
|
653 | 650 | if self.upsamplers is not None: |
654 | | - if feat_cache is not None: |
655 | | - x, feat_cache, feat_idx = self.upsamplers[0](x, feat_cache, feat_idx) |
656 | | - else: |
657 | | - x, _, _ = self.upsamplers[0](x) |
| 651 | + x, feat_cache, feat_idx = self.upsamplers[0](x, feat_cache, feat_idx) |
658 | 652 | return x, feat_cache, feat_idx |
659 | 653 |
|
660 | 654 |
|
@@ -950,31 +944,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): |
950 | 944 | return x, feat_cache, jnp.array(feat_idx, dtype=jnp.int32) |
951 | 945 |
|
952 | 946 |
|
953 | | -class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution): |
954 | | - pass |
955 | | - |
956 | | - |
957 | | -def _wan_diag_gauss_dist_flatten(dist): |
958 | | - return (dist.mean, dist.logvar, dist.std, dist.var), (dist.deterministic,) |
959 | | - |
960 | | - |
961 | | -def _wan_diag_gauss_dist_unflatten(aux, children): |
962 | | - mean, logvar, std, var = children |
963 | | - deterministic = aux[0] |
964 | | - obj = WanDiagonalGaussianDistribution.__new__(WanDiagonalGaussianDistribution) |
965 | | - obj.mean = mean |
966 | | - obj.logvar = logvar |
967 | | - obj.std = std |
968 | | - obj.var = var |
969 | | - obj.deterministic = deterministic |
970 | | - return obj |
971 | | - |
972 | | - |
973 | | -tree_util.register_pytree_node( |
974 | | - WanDiagonalGaussianDistribution, |
975 | | - _wan_diag_gauss_dist_flatten, |
976 | | - _wan_diag_gauss_dist_unflatten, |
977 | | -) |
978 | 947 |
|
979 | 948 |
|
980 | 949 | class AutoencoderKLWanCache: |
|
0 commit comments