Skip to content

Commit 00ae6fc

Browse files
committed
reformatted
1 parent 0bc8661 commit 00ae6fc

2 files changed

Lines changed: 7 additions & 3 deletions

File tree

src/maxdiffusion/models/vae_flax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,7 @@ def __call__(self, sample, sample_posterior=False, deterministic: bool = True, r
932932

933933
return FlaxDecoderOutput(sample=sample)
934934

935+
935936
class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution):
936937
pass
937938

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
from ...configuration_utils import ConfigMixin
2525
from ..modeling_flax_utils import FlaxModelMixin, get_activation
2626
from ... import common_types
27-
from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput, WanDiagonalGaussianDistribution)
27+
from ..vae_flax import (
28+
FlaxAutoencoderKLOutput,
29+
FlaxDiagonalGaussianDistribution,
30+
FlaxDecoderOutput,
31+
WanDiagonalGaussianDistribution,
32+
)
2833

2934
BlockSizes = common_types.BlockSizes
3035

@@ -944,8 +949,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
944949
return x, feat_cache, jnp.array(feat_idx, dtype=jnp.int32)
945950

946951

947-
948-
949952
class AutoencoderKLWanCache:
950953

951954
def __init__(self, module):

0 commit comments

Comments
 (0)