Skip to content

Commit 78ab13d

Browse files
committed
refactor
1 parent 7e66fd6 commit 78ab13d

2 files changed

Lines changed: 30 additions & 34 deletions

File tree

src/maxdiffusion/models/vae_flax.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import flax
2222
import flax.linen as nn
2323
import jax
24+
from jax import tree_util
2425
import jax.numpy as jnp
2526
from flax.core.frozen_dict import FrozenDict
2627

@@ -930,3 +931,29 @@ def __call__(self, sample, sample_posterior=False, deterministic: bool = True, r
930931
return (sample,)
931932

932933
return FlaxDecoderOutput(sample=sample)
934+
935+
class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution):
936+
pass
937+
938+
939+
def _wan_diag_gauss_dist_flatten(dist):
940+
return (dist.mean, dist.logvar, dist.std, dist.var), (dist.deterministic,)
941+
942+
943+
def _wan_diag_gauss_dist_unflatten(aux, children):
944+
mean, logvar, std, var = children
945+
deterministic = aux[0]
946+
obj = WanDiagonalGaussianDistribution.__new__(WanDiagonalGaussianDistribution)
947+
obj.mean = mean
948+
obj.logvar = logvar
949+
obj.std = std
950+
obj.var = var
951+
obj.deterministic = deterministic
952+
return obj
953+
954+
955+
tree_util.register_pytree_node(
956+
WanDiagonalGaussianDistribution,
957+
_wan_diag_gauss_dist_flatten,
958+
_wan_diag_gauss_dist_unflatten,
959+
)

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ...configuration_utils import ConfigMixin
2525
from ..modeling_flax_utils import FlaxModelMixin, get_activation
2626
from ... import common_types
27-
from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput)
27+
from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput, WanDiagonalGaussianDistribution)
2828

2929
BlockSizes = common_types.BlockSizes
3030

@@ -645,16 +645,10 @@ def __init__(
645645

646646
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
647647
for resnet in self.resnets:
648-
if feat_cache is not None:
649-
x, feat_cache, feat_idx = resnet(x, feat_cache, feat_idx)
650-
else:
651-
x, _, _ = resnet(x)
648+
x, feat_cache, feat_idx = resnet(x, feat_cache, feat_idx)
652649

653650
if self.upsamplers is not None:
654-
if feat_cache is not None:
655-
x, feat_cache, feat_idx = self.upsamplers[0](x, feat_cache, feat_idx)
656-
else:
657-
x, _, _ = self.upsamplers[0](x)
651+
x, feat_cache, feat_idx = self.upsamplers[0](x, feat_cache, feat_idx)
658652
return x, feat_cache, feat_idx
659653

660654

@@ -950,31 +944,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
950944
return x, feat_cache, jnp.array(feat_idx, dtype=jnp.int32)
951945

952946

953-
class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution):
954-
pass
955-
956-
957-
def _wan_diag_gauss_dist_flatten(dist):
958-
return (dist.mean, dist.logvar, dist.std, dist.var), (dist.deterministic,)
959-
960-
961-
def _wan_diag_gauss_dist_unflatten(aux, children):
962-
mean, logvar, std, var = children
963-
deterministic = aux[0]
964-
obj = WanDiagonalGaussianDistribution.__new__(WanDiagonalGaussianDistribution)
965-
obj.mean = mean
966-
obj.logvar = logvar
967-
obj.std = std
968-
obj.var = var
969-
obj.deterministic = deterministic
970-
return obj
971-
972-
973-
tree_util.register_pytree_node(
974-
WanDiagonalGaussianDistribution,
975-
_wan_diag_gauss_dist_flatten,
976-
_wan_diag_gauss_dist_unflatten,
977-
)
978947

979948

980949
class AutoencoderKLWanCache:

0 commit comments

Comments
 (0)