@@ -672,101 +672,6 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
         return x, new_cache
 
 
-class WanDownBlock(nnx.Module):
-    def __init__(
-        self,
-        in_dim: int,
-        out_dim: int,
-        num_res_blocks: int,
-        rngs: nnx.Rngs,
-        dropout: float = 0.0,
-        downsample_mode: Optional[str] = None,
-        add_attention: bool = False,
-        non_linearity: str = "silu",
-        mesh: jax.sharding.Mesh = None,
-        dtype: jnp.dtype = jnp.float32,
-        weights_dtype: jnp.dtype = jnp.float32,
-        precision: jax.lax.Precision = None,
-    ):
-        self.layers = nnx.List([])
-        current_dim = in_dim
-        for _ in range(num_res_blocks):
-            self.layers.append(
-                WanResidualBlock(
-                    in_dim=current_dim,
-                    out_dim=out_dim,
-                    dropout=dropout,
-                    non_linearity=non_linearity,
-                    rngs=rngs,
-                    mesh=mesh,
-                    dtype=dtype,
-                    weights_dtype=weights_dtype,
-                    precision=precision,
-                )
-            )
-            if add_attention:
-                self.layers.append(
-                    WanAttentionBlock(
-                        dim=out_dim,
-                        rngs=rngs,
-                        mesh=mesh,
-                        dtype=dtype,
-                        weights_dtype=weights_dtype,
-                        precision=precision,
-                    )
-                )
-            current_dim = out_dim
-
-        if downsample_mode is not None:
-            self.layers.append(
-                WanResample(
-                    out_dim,
-                    mode=downsample_mode,
-                    rngs=rngs,
-                    mesh=mesh,
-                    dtype=dtype,
-                    weights_dtype=weights_dtype,
-                    precision=precision,
-                )
-            )
-
-    def initialize_cache(self, batch_size, height, width, dtype):
-        """Initialize cache for all layers."""
-        cache = {"layers": []}
-        h_curr, w_curr = height, width
-        for layer in self.layers:
-            if isinstance(layer, WanResidualBlock):
-                cache["layers"].append(
-                    layer.initialize_cache(batch_size, h_curr, w_curr, dtype)
-                )
-            elif isinstance(layer, WanResample):
-                cache["layers"].append(
-                    layer.initialize_cache(batch_size, h_curr, w_curr, dtype)
-                )
-                if layer.mode in ["downsample2d", "downsample3d"]:
-                    h_curr, w_curr = h_curr // 2, w_curr // 2
-            else:  # Attention
-                cache["layers"].append(None)
-        return cache
-
-    def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
-        """Pure function: returns (output, new_cache)."""
-        if cache is None:
-            cache = {}
-        new_cache = {"layers": []}
-
-        current_caches = cache.get("layers", [None] * len(self.layers))
-        for i, layer in enumerate(self.layers):
-            if isinstance(layer, (WanResidualBlock, WanResample)):
-                x, c = layer(x, current_caches[i])
-                new_cache["layers"].append(c)
-            else:  # Attention
-                x = layer(x)
-                new_cache["layers"].append(None)
-
-        return x, new_cache
-
-
 class WanUpBlock(nnx.Module):
     def __init__(
         self,
@@ -883,27 +788,45 @@ def __init__(
 
         self.down_blocks = nnx.List([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
-            add_attention = scale in attn_scales
-            downsample_mode = None
+            for _ in range(num_res_blocks):
+                self.down_blocks.append(
+                    WanResidualBlock(
+                        in_dim=in_dim,
+                        out_dim=out_dim,
+                        dropout=dropout,
+                        rngs=rngs,
+                        mesh=mesh,
+                        dtype=dtype,
+                        weights_dtype=weights_dtype,
+                        precision=precision,
+                    )
+                )
+                if scale in attn_scales:
+                    self.down_blocks.append(
+                        WanAttentionBlock(
+                            dim=out_dim,
+                            rngs=rngs,
+                            mesh=mesh,
+                            dtype=dtype,
+                            weights_dtype=weights_dtype,
+                            precision=precision,
+                        )
+                    )
+                in_dim = out_dim
             if i != len(dim_mult) - 1:
-                downsample_mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
-                scale /= 2.0
-
-            self.down_blocks.append(
-                WanDownBlock(
-                    in_dim=in_dim,
-                    out_dim=out_dim,
-                    num_res_blocks=num_res_blocks,
-                    dropout=dropout,
-                    downsample_mode=downsample_mode,
-                    add_attention=add_attention,
-                    rngs=rngs,
-                    mesh=mesh,
-                    dtype=dtype,
-                    weights_dtype=weights_dtype,
-                    precision=precision,
+                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+                self.down_blocks.append(
+                    WanResample(
+                        out_dim,
+                        mode=mode,
+                        rngs=rngs,
+                        mesh=mesh,
+                        dtype=dtype,
+                        weights_dtype=weights_dtype,
+                        precision=precision,
+                    )
                 )
-            )
+                scale /= 2.0
 
         self.mid_block = WanMidBlock(
             dim=out_dim,
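
For reference, a standalone sketch (not part of the diff) of the layer sequence the flattened construction loop above produces. The plan_down_blocks helper and the config values are hypothetical, and it assumes scale starts at 1.0 in the surrounding constructor:

def plan_down_blocks(dim_mult, num_res_blocks, attn_scales, temperal_downsample):
    """Return the ordered kinds of layers the construction loop appends."""
    plan = []
    scale = 1.0  # assumed initial value, mirroring the surrounding constructor
    for i in range(len(dim_mult)):
        for _ in range(num_res_blocks):
            plan.append("residual")
            if scale in attn_scales:
                plan.append("attention")
        if i != len(dim_mult) - 1:
            plan.append("downsample3d" if temperal_downsample[i] else "downsample2d")
            scale /= 2.0
    return plan

# Hypothetical example values:
print(plan_down_blocks(
    dim_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    attn_scales=[1.0],
    temperal_downsample=[False, True, True],
))
# ['residual', 'attention', 'residual', 'attention', 'downsample2d',
#  'residual', 'residual', 'downsample3d',
#  'residual', 'residual', 'downsample3d',
#  'residual', 'residual']
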
@@ -940,14 +863,19 @@ def init_cache(self, batch_size, height, width, dtype):
         cache["down_blocks"] = []
 
         h_curr, w_curr = height, width
-        for block in self.down_blocks:
-            cache["down_blocks"].append(
-                block.initialize_cache(batch_size, h_curr, w_curr, dtype)
-            )
-            # Update dimensions if downsampling
-            if block.layers and isinstance(block.layers[-1], WanResample):
-                if block.layers[-1].mode in ["downsample2d", "downsample3d"]:
+        for layer in self.down_blocks:
+            if isinstance(layer, WanResidualBlock):
+                cache["down_blocks"].append(
+                    layer.initialize_cache(batch_size, h_curr, w_curr, dtype)
+                )
+            elif isinstance(layer, WanResample):
+                cache["down_blocks"].append(
+                    layer.initialize_cache(batch_size, h_curr, w_curr, dtype)
+                )
+                if layer.mode in ["downsample2d", "downsample3d"]:
                     h_curr, w_curr = h_curr // 2, w_curr // 2
+            else:  # Attention
+                cache["down_blocks"].append(None)
 
         cache["mid_block"] = self.mid_block.initialize_cache(
             batch_size, h_curr, w_curr, dtype
@@ -969,9 +897,13 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
         new_cache["down_blocks"] = []
         current_down_caches = cache.get("down_blocks", [None] * len(self.down_blocks))
 
-        for i, block in enumerate(self.down_blocks):
-            x, c = block(x, current_down_caches[i])
-            new_cache["down_blocks"].append(c)
+        for i, layer in enumerate(self.down_blocks):
+            if isinstance(layer, (WanResidualBlock, WanResample)):
+                x, c = layer(x, current_down_caches[i])
+                new_cache["down_blocks"].append(c)
+            else:  # Attention
+                x = layer(x)
+                new_cache["down_blocks"].append(None)
 
         x, c = self.mid_block(x, cache.get("mid_block"))
         new_cache["mid_block"] = c
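
Below is a minimal, self-contained sketch of the cache-threading pattern the flattened encoder now follows: stateful layers are pure functions returning (output, new_cache_entry), stateless attention layers keep a None slot so the cache list stays aligned with the layer list. ResBlockStub, AttentionStub, and call_down_blocks are hypothetical placeholders, not the real Wan modules:

from typing import Any, Dict, List, Optional

import jax
import jax.numpy as jnp


class ResBlockStub:
    """Stands in for WanResidualBlock / WanResample: carries a cache entry."""

    def initialize_cache(self, batch_size, height, width, dtype):
        return jnp.zeros((batch_size, height, width), dtype)

    def __call__(self, x, cache):
        return x + 1.0, cache  # a real block would also update the cache


class AttentionStub:
    """Stands in for WanAttentionBlock: stateless, so no cache entry."""

    def __call__(self, x):
        return x * 2.0


def call_down_blocks(layers: List[Any], x: jax.Array, cache: Optional[Dict[str, Any]] = None):
    """Mirrors the encoder loop above: returns (output, new_cache)."""
    if cache is None:
        cache = {}
    new_cache = {"down_blocks": []}
    current = cache.get("down_blocks", [None] * len(layers))
    for i, layer in enumerate(layers):
        if isinstance(layer, ResBlockStub):
            x, c = layer(x, current[i])
            new_cache["down_blocks"].append(c)
        else:  # attention-style layer: pass through, keep the slot aligned
            x = layer(x)
            new_cache["down_blocks"].append(None)
    return x, new_cache


layers = [ResBlockStub(), AttentionStub(), ResBlockStub()]
cache = {"down_blocks": [
    l.initialize_cache(1, 4, 4, jnp.float32) if isinstance(l, ResBlockStub) else None
    for l in layers
]}
y, cache = call_down_blocks(layers, jnp.ones((1, 4, 4)), cache)
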