from ...configuration_utils import ConfigMixin, flax_register_to_config
from ..modeling_flax_utils import FlaxModelMixin
from ... import common_types
-from ..vae_flax import FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution
+from ..vae_flax import (
+    FlaxAutoencoderKLOutput,
+    FlaxDiagonalGaussianDistribution,
+    FlaxDecoderOutput,
+)

BlockSizes = common_types.BlockSizes

@@ -82,7 +86,7 @@ def __init__(
            (0, 0)  # Channel dimension - no padding
        )

-        # Store the amount of padding needed *before* the depth dimension for caching logoic
+        # Store the amount of padding needed *before* the depth dimension for caching logic
        self._depth_padding_before = self._causal_padding[1][0]  # 2 * padding_tuple[0]

        self.conv = nnx.Conv(
@@ -103,7 +107,6 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Arr
        # Ensure cache has same spatial/channel dims, potentially different depth
        assert cache_x.shape[0] == x.shape[0] and \
            cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch"
-
        cache_len = cache_x.shape[1]
        x = jnp.concatenate([cache_x, x], axis=1)  # Concat along depth (D)
@@ -166,24 +169,13 @@ def __init__(self, scale_factor: Tuple[float, float], method: str = 'nearest'):
    def __call__(self, x: jax.Array) -> jax.Array:
        input_dtype = x.dtype
        in_shape = x.shape
-        is_3d = len(in_shape) == 5
-        n, d, h, w, c = in_shape if is_3d else (in_shape[0], 1, in_shape[1], in_shape[2], in_shape[3])
-
+        assert len(in_shape) == 4, "WanUpsample expects a 4D (N, H, W, C) tensor."
+        n, h, w, c = in_shape
        target_h = int(h * self.scale_factor[0])
        target_w = int(w * self.scale_factor[1])
-
-        # jax.image.resize expects (..., H, W, C)
-        if is_3d:
-            x_reshaped = x.reshape(n * d, h, w, c)
-            out_reshaped = jax.image.resize(x_reshaped.astype(jnp.float32),
-                                            (n * d, target_h, target_w, c),
-                                            method=self.method)
-            out = out_reshaped.reshape(n, d, target_h, target_w, c)
-        else:  # Asumming (N, H, W, C)
-            out = jax.image.resize(x.astype(jnp.float32),
-                                   (n, target_h, target_w, c),
-                                   method=self.method)
-
+        out = jax.image.resize(x.astype(jnp.float32),
+                               (n, target_h, target_w, c),
+                               method=self.method)
        return out.astype(input_dtype)

class Identity(nnx.Module):
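
A quick sanity check of the new WanUpsample contract (shapes here are illustrative; after this change the module is purely spatial, with temporal upsampling handled by time_conv in WanResample):

    x = jnp.ones((2, 16, 16, 8))                  # (N, H, W, C)
    y = WanUpsample(scale_factor=(2.0, 2.0), method="nearest")(x)
    assert y.shape == (2, 32, 32, 8)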
@@ -256,7 +248,7 @@ def __init__(
            )
        elif mode == "upsample3d":
            self.resample = nnx.Sequential(
-                WanUpsample(scale_factor=(2.0, 2.0), method="nearest"),
+                WanUpsample(scale_factor=(2.0, 2.0, 2.0), method="nearest"),  # third factor is ignored; WanUpsample is spatial-only
                nnx.Conv(
                    dim,
                    dim // 2,
@@ -305,6 +297,29 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
        n, d, h, w, c = x.shape
        assert c == self.dim

+        if self.mode == "upsample3d":
+            if feat_cache is not None:
+                idx = feat_idx[0]
+                if feat_cache[idx] is None:
+                    # "Rep" marks the first chunk: time_conv then runs without a cache
+                    feat_cache[idx] = "Rep"
+                    feat_idx[0] += 1
+                else:
+                    cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+                    if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+                        # also cache the last frame of the previous chunk
+                        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+                    if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+                        cache_x = jnp.concatenate([jnp.zeros_like(cache_x), cache_x], axis=1)
+                    if feat_cache[idx] == "Rep":
+                        x = self.time_conv(x)
+                    else:
+                        x = self.time_conv(x, feat_cache[idx])
+                    feat_cache[idx] = cache_x
+                    feat_idx[0] += 1
+                    # time_conv doubled the channels; split them into two groups and
+                    # interleave the groups along depth (channels-last layout)
+                    x = x.reshape(n, d, h, w, 2, c)
+                    x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
+                    x = x.reshape(n, d * 2, h, w, c)
+                    d = x.shape[1]
        x = x.reshape(n * d, h, w, c)
        x = self.resample(x)
        h_new, w_new, c_new = x.shape[1:]
@@ -371,7 +386,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
-            if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+            if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
                cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)

            x = self.conv1(x, feat_cache[idx])
@@ -387,7 +402,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
-            if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+            if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
                cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
            x = self.conv2(x, feat_cache[idx])
            feat_cache[idx] = cache_x
@@ -458,7 +473,7 @@ def __init__(
        attentions = []
        for _ in range(num_layers):
            attentions.append(WanAttentionBlock(dim=dim, rngs=rngs))
-            resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs,dropout=dropout, non_linearity=non_linearity))
+            resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity))
        self.attentions = attentions
        self.resnets = resnets
@@ -467,7 +482,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                x = attn(x)
-            x = resnet(x)
+            x = resnet(x, feat_cache, feat_idx)
        return x

class WanUpBlock(nnx.Module):
@@ -482,19 +497,31 @@ def __init__(
        non_linearity: str = "silu"
    ):
        # Create layers list
-        self.resnets = []
+        resnets = []
        # Add residual blocks and attention if needed
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
-            self.resnets.append(WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs))
+            resnets.append(WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs))
            current_dim = out_dim
+        self.resnets = resnets

        # Add upsampling layer if needed.
        self.upsamplers = None
        if upsample_mode is not None:
-            self.upsamplers = WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs)
+            self.upsamplers = [WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs)]
    def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
+        for resnet in self.resnets:
+            if feat_cache is not None:
+                x = resnet(x, feat_cache, feat_idx)
+            else:
+                x = resnet(x)
+
+        if self.upsamplers is not None:
+            if feat_cache is not None:
+                x = self.upsamplers[0](x, feat_cache, feat_idx)
+            else:
+                x = self.upsamplers[0](x)
        return x

class WanEncoder3d(nnx.Module):
@@ -655,7 +682,13 @@ def __init__(
        )

        # middle_blocks
-        self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1)
+        self.mid_block = WanMidBlock(
+            dim=dims[0],
+            rngs=rngs,
+            dropout=dropout,
+            non_linearity=non_linearity,
+            num_layers=1
+        )

        # upsample blocks
        self.up_blocks = []
@@ -668,7 +701,6 @@ def __init__(
            upsample_mode = None
            if i != len(dim_mult) - 1:
                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
-
            # Create and add the upsampling block
            up_block = WanUpBlock(
                in_dim=in_dim,
@@ -686,8 +718,7 @@ def __init__(
            scale *= 2.0

        # output blocks
-        self.norm_out = nnx.RMSNorm(num_features=out_dim, )
-        self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs)
+        self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False)
        self.conv_out = WanCausalConv3d(
            rngs=rngs,
            in_channels=out_dim,
@@ -697,7 +728,39 @@ def __init__(
        )

    def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
-        x = self.conv_in(x)
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+            if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+                # also cache the last frame of the previous chunk
+                cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+            x = self.conv_in(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_in(x)
+
+        ## middle
+        x = self.mid_block(x, feat_cache, feat_idx)
+
+        ## upsamples
+        for up_block in self.up_blocks:
+            x = up_block(x, feat_cache, feat_idx)
+
+        ## head
+        x = self.norm_out(x)
+        x = self.nonlinearity(x)
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+            if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+                # also cache the last frame of the previous chunk
+                cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+            x = self.conv_out(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_out(x)
        return x

class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin):
@@ -723,6 +786,8 @@ def __init__(
        self.z_dim = z_dim
        self.temperal_downsample = temperal_downsample
        self.temporal_upsample = temperal_downsample[::-1]
+        self.latents_mean = latents_mean
+        self.latents_std = latents_std

        self.encoder = WanEncoder3d(
            rngs=rngs,
@@ -747,16 +812,16 @@ def __init__(
            kernel_size=1,
        )

-        # self.decoder = WanDecoder3d(
-        #     rngs=rngs,
-        #     dim=base_dim,
-        #     z_dim=z_dim,
-        #     dim_mult=dim_mult,
-        #     num_res_blocks=num_res_blocks,
-        #     attn_scales=attn_scales,
-        #     temperal_upsample=self.temporal_upsample,
-        #     dropout=dropout
-        # )
+        self.decoder = WanDecoder3d(
+            rngs=rngs,
+            dim=base_dim,
+            z_dim=z_dim,
+            dim_mult=dim_mult,
+            num_res_blocks=num_res_blocks,
+            attn_scales=attn_scales,
+            temperal_upsample=self.temporal_upsample,
+            dropout=dropout
+        )
        self.clear_cache()

    def clear_cache(self):
@@ -769,9 +834,9 @@ def _count_conv3d(module):
                    count += 1
            return count

-        # self._conv_num = _count_conv3d(self.decoder)
-        # self._conv_idx = [0]
-        # self._feat_map = [None] * self._conv_num
+        self._conv_num = _count_conv3d(self.decoder)
+        self._conv_idx = [0]
+        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = _count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
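
How the decode cache is threaded (a sketch of the protocol as it appears in this diff, following the reference WAN implementation; names below are illustrative): `_feat_map` holds one slot per causal conv, and `_conv_idx` is a mutable cursor that every block advances in traversal order, so the i-th conv always reads and writes slot i on every chunk.

    feat_map = [None] * conv_num   # one cache slot per WanCausalConv3d
    for chunk in frame_chunks:     # decode processes one latent frame at a time
        conv_idx = [0]             # cursor is reset, then bumped by each conv
        out = decoder(chunk, feat_cache=feat_map, feat_idx=conv_idx)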
@@ -817,4 +882,35 @@ def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencode
        if not return_dict:
            return (posterior,)
        return FlaxAutoencoderKLOutput(latent_dist=posterior)
+
+    def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]:
+        self.clear_cache()
+        iter_ = z.shape[1]
+        x = self.post_quant_conv(z)
+        for i in range(iter_):
+            self._conv_idx = [0]
+            if i == 0:
+                out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+            else:
+                out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                out = jnp.concatenate([out, out_], axis=1)
+
+        out = jnp.clip(out, -1.0, 1.0)
+        self.clear_cache()
+        if not return_dict:
+            return (out,)
+
+        return FlaxDecoderOutput(sample=out)
+
+    def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]:
+        if z.shape[-1] != self.z_dim:
+            # move channels last for JAX
+            z = jnp.transpose(z, (0, 2, 3, 4, 1))
+        assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}), got {z.shape}"
+        decoded = self._decode(z).sample
+        if not return_dict:
+            return (decoded,)
+        return FlaxDecoderOutput(sample=decoded)
+
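
To close, a minimal round-trip sketch of how this VAE might be driven once the PR lands (illustrative only: the AutoencoderKLWan constructor defaults and the nnx.Rngs seeding are assumptions, not this diff's verified API):

    import jax.numpy as jnp
    from flax import nnx

    vae = AutoencoderKLWan(rngs=nnx.Rngs(0))        # assumes usable defaults
    video = jnp.zeros((1, 9, 64, 64, 3))            # (N, D, H, W, C), channels-last
    latents = vae.encode(video).latent_dist.mode()  # FlaxAutoencoderKLOutput
    frames = vae.decode(latents).sample             # FlaxDecoderOutput, clipped to [-1, 1]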