@@ -289,6 +289,14 @@ def __init__(
                 kernel_size=(1, 3, 3),
                 stride=(1, 2, 2)
             )
+            self.time_conv = WanCausalConv3d(
+                rngs=rngs,
+                in_channels=dim,
+                out_channels=dim,
+                kernel_size=(3, 1, 1),
+                stride=(2, 1, 1),
+                padding=(0, 0, 0)
+            )
         else:
             self.resample = Identity()
 
@@ -302,6 +310,18 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
         h_new, w_new, c_new = x.shape[1:]
         x = x.reshape(n, d, h_new, w_new, c_new)
 
+        if self.mode == "downsample3d":
+            if feat_cache is not None:
+                idx = feat_idx[0]
+                if feat_cache[idx] is None:
+                    feat_cache[idx] = jnp.copy(x)
+                    feat_idx[0] += 1
+                else:
+                    cache_x = jnp.copy(x[:, -1:, :, :, :])
+                    x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
+                    feat_cache[idx] = cache_x
+                    feat_idx[0] += 1
+
         return x
 
 class WanResidualBlock(nnx.Module):
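Taken together, the two hunks above add a one-frame temporal carry for the "downsample3d" path: the last frame of the previous chunk is cached and prepended to the next chunk before the kernel-3, stride-2 time_conv, so temporal downsampling stays continuous across chunk boundaries. A minimal sketch of the steady-state step (hypothetical names; the first-chunk branch, which only primes the cache, is omitted):

import jax.numpy as jnp

def downsample3d_step(chunk, cache_frame, time_conv):
    # chunk:       (N, T, H, W, C) frames of the current chunk
    # cache_frame: (N, 1, H, W, C) last frame carried over from the previous chunk
    new_cache = jnp.copy(chunk[:, -1:, :, :, :])       # remember the last frame for the next call
    x = jnp.concatenate([cache_frame, chunk], axis=1)  # stitch the chunks together along time
    return time_conv(x), new_cache                     # the stride-2 conv halves the temporal length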
@@ -343,7 +363,6 @@ def __init__(
 
     def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
         # Apply shortcut connection
-        #breakpoint()
         h = self.conv_shortcut(x)
 
         x = self.norm1(x)
@@ -505,7 +524,8 @@ def __init__(
             if i != len(dim_mult) - 1:
                 mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                 self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs))
-
+                scale /= 2.0
+
         # middle_blocks
         self.mid_block = WanMidBlock(
             dim=out_dim,
@@ -516,7 +536,12 @@ def __init__(
         )
 
         # output blocks
-        self.norm_out = WanRMS_norm(out_dim, images=False, rngs=rngs)
+        self.norm_out = WanRMS_norm(
+            out_dim,
+            channel_first=False,
+            images=False,
+            rngs=rngs
+        )
         self.conv_out = WanCausalConv3d(
             rngs=rngs,
             in_channels=out_dim,
@@ -526,14 +551,39 @@ def __init__(
         )
 
     def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
-        # (1, 1, 480, 720, 3)
-        x = self.conv_in(x)
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = jnp.copy(x[:, -CACHE_T:, :, :])
+            if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+                # cache last frame of the last two chunk
+                cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+            x = self.conv_in(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_in(x)
         # (1, 1, 480, 720, 96)
         for layer in self.down_blocks:
-            x = layer(x)
+            if feat_cache is not None:
+                x = layer(x, feat_cache, feat_idx)
+            else:
+                x = layer(x)
 
-        x = self.mid_block(x)
-        breakpoint()
+        x = self.mid_block(x, feat_cache, feat_idx)
+
+        x = self.norm_out(x)
+        x = self.nonlinearity(x)
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+            if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunk
+                cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+            x = self.conv_out(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_out(x)
         return x
 
 class WanDecoder3d(nnx.Module):
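The conv_in / conv_out branches in the hunk above repeat one caching pattern: run the causal conv against the cache from the previous chunk, then roll the cache forward so it always ends with the last CACHE_T frames seen (borrowing one frame from the old cache when the current chunk is a single frame). A sketch of that pattern as a hypothetical standalone helper, assuming CACHE_T = 2 and simplifying the None-cache case:

import jax.numpy as jnp

CACHE_T = 2  # assumed value

def cached_conv(conv, x, cache):
    # x: (N, T, H, W, C); cache: (N, <=CACHE_T, H, W, C) or None
    new_cache = jnp.copy(x[:, -CACHE_T:, :, :, :])
    if new_cache.shape[1] < 2 and cache is not None:
        # current chunk is a single frame: keep one frame of the old cache as well
        new_cache = jnp.concatenate(
            [jnp.expand_dims(cache[:, -1, :, :, :], axis=1), new_cache], axis=1
        )
    y = conv(x, cache) if cache is not None else conv(x)  # conv against the previous cache
    return y, new_cache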
@@ -626,9 +676,7 @@ def __init__(
         )
 
     def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
-        breakpoint()
         x = self.conv_in(x)
-        breakpoint()
         return x
 
 class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin):
@@ -696,9 +744,7 @@ def _count_conv3d(module):
             count = 0
             node_types = nnx.graph.iter_graph([module])
             for path, value in node_types:
-                #breakpoint()
                 if isinstance(value, WanCausalConv3d):
-                    print("value: ", value)
                     count += 1
             return count
 
@@ -711,6 +757,7 @@ def _count_conv3d(module):
         self._enc_feat_map = [None] * self._enc_conv_num
 
     def _encode(self, x: jax.Array):
+        self.clear_cache()
         if x.shape[-1] != 3:
             # reshape channel last for JAX
             x = jnp.transpose(x, (0, 2, 3, 4, 1))
@@ -721,6 +768,7 @@ def _encode(self, x: jax.Array):
         t = x.shape[1]
         iter_ = 1 + (t - 1) // 4
         for i in range(iter_):
+            self._enc_conv_idx = [0]
             if i == 0:
                 out = self.encoder(
                     x[:, :1, :, :, :],
@@ -729,24 +777,23 @@ def _encode(self, x: jax.Array):
                 )
             else:
                 out_ = self.encoder(
-                    x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
+                    x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :],
                     feat_cache=self._enc_feat_map,
                     feat_idx=self._enc_conv_idx
                 )
                 out = jnp.concatenate([out, out_], axis=1)
-
-        # enc = self.quant_conv(out)
-        # mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
-        # enc = jnp.concatenate([mu, logvar], dim=1)
-        # self.clear_cache()
+        enc = self.quant_conv(out)
+        mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
+        enc = jnp.concatenate([mu, logvar], axis=-1)
+        self.clear_cache()
         # return enc
-        return x
+        return enc
 
     def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
         """Encode video into latent distribution."""
         h = self._encode(x)
         posterior = FlaxDiagonalGaussianDistribution(h)
         if not return_dict:
             return (posterior,)
-        return FlaxAutoencoderKLOutput(latent_dict=posterior)
+        return FlaxAutoencoderKLOutput(latent_dist=posterior)
 
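The rewritten _encode streams the clip through the encoder in chunks: one pass for the first frame, then one pass per group of four frames, which is where iter_ = 1 + (t - 1) // 4 comes from; the feature caches above are what let the chunked passes line up at the boundaries. A small illustrative helper (not part of the module) that reproduces the slicing schedule:

def encode_chunk_bounds(t):
    # first pass covers frame 0 only, every later pass covers 4 frames
    bounds = [(0, 1)]
    for i in range(1, 1 + (t - 1) // 4):
        bounds.append((1 + 4 * (i - 1), 1 + 4 * i))
    return bounds

# e.g. encode_chunk_bounds(17) -> [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]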