@@ -19,11 +19,10 @@
 import jax
 import jax.numpy as jnp
 from flax import nnx
-from ...configuration_utils import ConfigMixin, flax_register_to_config
+from ...configuration_utils import ConfigMixin
 from ..modeling_flax_utils import FlaxModelMixin
 from ... import common_types
 from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput)
-import numpy as np
 BlockSizes = common_types.BlockSizes
 
 CACHE_T = 2
@@ -60,13 +59,6 @@ def __init__(
       stride: Union[int, Tuple[int, int, int]] = 1,
       padding: Union[int, Tuple[int, int, int]] = 0,
       use_bias: bool = True,
-      flash_min_seq_length: int = 4096,
-      flash_block_sizes: BlockSizes = None,
-      mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
-      precision: jax.lax.Precision = None,
-      attention: str = "dot_product",
   ):
     self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size")
     self.stride = _canonicalize_tuple(stride, 3, "stride")
@@ -191,13 +183,6 @@ def __init__(
       rngs: nnx.Rngs,
       kernel_size: Union[int, Tuple[int, int, int]],
       stride: Union[int, Tuple[int, int, int]] = 1,
-      flash_min_seq_length: int = 4096,
-      flash_block_sizes: BlockSizes = None,
-      mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
-      precision: jax.lax.Precision = None,
-      attention: str = "dot_product",
   ):
     self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs)
 
@@ -212,13 +197,6 @@ def __init__(
       dim: int,
       mode: str,
       rngs: nnx.Rngs,
-      flash_min_seq_length: int = 4096,
-      flash_block_sizes: BlockSizes = None,
-      mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
-      precision: jax.lax.Precision = None,
-      attention: str = "dot_product",
   ):
     self.dim = dim
     self.mode = mode
@@ -548,7 +526,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
       feat_idx[0] += 1
     else:
       x = self.conv_in(x)
-      # (1, 1, 480, 720, 96)
     for layer in self.down_blocks:
       if feat_cache is not None:
         x = layer(x, feat_cache, feat_idx)
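
For reference, the net effect of these hunks is that the conv/resample constructors keep only the parameters they actually consume; the flash-attention and sharding knobs (flash_min_seq_length, flash_block_sizes, mesh, dtype, weights_dtype, precision, attention) were never used inside these modules. Below is a minimal, self-contained sketch of the resulting constructor shape. The class name Resample and the trivial __call__ are hypothetical; only the __init__ signature and the nnx.Conv construction are taken from the diff.

import jax
import jax.numpy as jnp
from typing import Tuple, Union
from flax import nnx


class Resample(nnx.Module):
  # Hypothetical stand-in for the module edited above; the signature and the
  # nnx.Conv construction mirror the diff, the rest is illustrative.
  def __init__(
      self,
      dim: int,
      rngs: nnx.Rngs,
      kernel_size: Union[int, Tuple[int, int, int]],
      stride: Union[int, Tuple[int, int, int]] = 1,
  ):
    self.conv = nnx.Conv(dim, dim, kernel_size=kernel_size, strides=stride, use_bias=True, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    return self.conv(x)


# Usage sketch: a 3-tuple kernel makes this a 3D conv over
# (batch, depth, height, width, channels) inputs.
layer = Resample(dim=96, rngs=nnx.Rngs(0), kernel_size=(3, 3, 3))
y = layer(jnp.ones((1, 4, 8, 8, 96)))  # 'SAME' padding (nnx.Conv default) keeps spatial dims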