Skip to content

Commit 4245b24

Browse files
Clean up unused code.
1 parent d9749e9 commit 4245b24

1 file changed

Lines changed: 1 addition & 24 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@
1919
import jax
2020
import jax.numpy as jnp
2121
from flax import nnx
22-
from ...configuration_utils import ConfigMixin, flax_register_to_config
22+
from ...configuration_utils import ConfigMixin
2323
from ..modeling_flax_utils import FlaxModelMixin
2424
from ... import common_types
2525
from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput)
26-
import numpy as np
2726
BlockSizes = common_types.BlockSizes
2827

2928
CACHE_T = 2
@@ -60,13 +59,6 @@ def __init__(
6059
stride: Union[int, Tuple[int, int, int]] = 1,
6160
padding: Union[int, Tuple[int, int, int]] = 0,
6261
use_bias: bool = True,
63-
flash_min_seq_length: int = 4096,
64-
flash_block_sizes: BlockSizes = None,
65-
mesh: jax.sharding.Mesh = None,
66-
dtype: jnp.dtype = jnp.float32,
67-
weights_dtype: jnp.dtype = jnp.float32,
68-
precision: jax.lax.Precision = None,
69-
attention: str = "dot_product",
7062
):
7163
self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size")
7264
self.stride = _canonicalize_tuple(stride, 3, "stride")
@@ -191,13 +183,6 @@ def __init__(
191183
rngs: nnx.Rngs,
192184
kernel_size: Union[int, Tuple[int, int, int]],
193185
stride: Union[int, Tuple[int, int, int]] = 1,
194-
flash_min_seq_length: int = 4096,
195-
flash_block_sizes: BlockSizes = None,
196-
mesh: jax.sharding.Mesh = None,
197-
dtype: jnp.dtype = jnp.float32,
198-
weights_dtype: jnp.dtype = jnp.float32,
199-
precision: jax.lax.Precision = None,
200-
attention: str = "dot_product",
201186
):
202187
self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs)
203188

@@ -212,13 +197,6 @@ def __init__(
212197
dim: int,
213198
mode: str,
214199
rngs: nnx.Rngs,
215-
flash_min_seq_length: int = 4096,
216-
flash_block_sizes: BlockSizes = None,
217-
mesh: jax.sharding.Mesh = None,
218-
dtype: jnp.dtype = jnp.float32,
219-
weights_dtype: jnp.dtype = jnp.float32,
220-
precision: jax.lax.Precision = None,
221-
attention: str = "dot_product",
222200
):
223201
self.dim = dim
224202
self.mode = mode
@@ -548,7 +526,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
548526
feat_idx[0] += 1
549527
else:
550528
x = self.conv_in(x)
551-
# (1, 1, 480, 720, 96)
552529
for layer in self.down_blocks:
553530
if feat_cache is not None:
554531
x = layer(x, feat_cache, feat_idx)

0 commit comments

Comments
 (0)