@@ -313,10 +313,49 @@ def __init__(
         dropout: float = 0.0,
         non_linearity: str = "silu",
     ):
-        pass
+        self.nonlinearity = get_activation(non_linearity)
+
+        # layers
+        self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False)
+        self.conv1 = WanCausalConv3d(
+            rngs=rngs,
+            in_channels=in_dim,
+            out_channels=out_dim,
+            kernel_size=3,
+            padding=1
+        )
+        self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False)
+        self.dropout = nnx.Dropout(dropout, rngs=rngs)
+        self.conv2 = WanCausalConv3d(
+            rngs=rngs,
+            in_channels=out_dim,
+            out_channels=out_dim,
+            kernel_size=3,
+            padding=1
+        )
+        self.conv_shortcut = WanCausalConv3d(
+            rngs=rngs,
+            in_channels=in_dim,
+            out_channels=out_dim,
+            kernel_size=1
+        ) if in_dim != out_dim else Identity()
+

     def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
-        return x
+        # Apply shortcut connection
+        #breakpoint()
+        h = self.conv_shortcut(x)
+
+        x = self.norm1(x)
+        x = self.nonlinearity(x)
+        x = self.conv1(x)
+
+        x = self.norm2(x)
+        x = self.nonlinearity(x)
+        x = self.dropout(x)
+        x = self.conv2(x)
+
+        return x + h

 class WanAttentionBlock(nnx.Module):
     def __init__(
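A minimal usage sketch of the residual block assembled in the hunk above. The class name WanResidualBlock and its in_dim/out_dim/rngs parameters are assumptions inferred from the body (only dropout and non_linearity appear in the visible signature); shapes follow the channel-last (N, D, H, W, C) layout used elsewhere in this diff.

import jax.numpy as jnp
from flax import nnx

# hypothetical example; names mirror the hunk above
block = WanResidualBlock(in_dim=96, out_dim=192, rngs=nnx.Rngs(0))
x = jnp.zeros((1, 1, 32, 32, 96))   # (N, D, H, W, C)
y = block(x)                        # conv_shortcut (1x1x1) projects 96 -> 192 channels for the skip
assert y.shape == (1, 1, 32, 32, 192)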
@@ -397,11 +436,11 @@ def __init__(

         # init block
         self.conv_in = WanCausalConv3d(
+            rngs=rngs,
             in_channels=3,
             out_channels=dims[0],
             kernel_size=3,
             padding=1,
-            rngs=rngs
         )

         # downsample blocks
@@ -439,6 +478,12 @@ def __init__(
         )

     def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
+        # (1, 1, 480, 720, 3)
+        x = self.conv_in(x)
+        # (1, 1, 480, 720, 96)
+        for layer in self.down_blocks:
+            x = layer(x)
+        breakpoint()
         return x

 class WanDecoder3d(nnx.Module):
@@ -480,7 +525,13 @@ def __init__(
         scale = 1.0 / 2 ** (len(dim_mult) - 2)

         # init block
-        self.conv_in = WanCausalConv3d(in_channels=z_dim, out_channels=dims[0], kernel_size=3, padding=1, rngs=rngs)
+        self.conv_in = WanCausalConv3d(
+            rngs=rngs,
+            in_channels=z_dim,
+            out_channels=dims[0],
+            kernel_size=3,
+            padding=1
+        )

         # middle_blocks
         self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1)
@@ -516,7 +567,13 @@ def __init__(
         # output blocks
         self.norm_out = nnx.RMSNorm(num_features=out_dim, )
         self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs)
-        self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=3, kernel_size=3, padding=1, rngs=rngs)
+        self.conv_out = WanCausalConv3d(
+            rngs=rngs,
+            in_channels=out_dim,
+            out_channels=3,
+            kernel_size=3,
+            padding=1
+        )

     def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
         breakpoint()
@@ -533,7 +590,7 @@ def __init__(
         dim_mult: Tuple[int] = [1, 2, 4, 4],
         num_res_blocks: int = 2,
         attn_scales: List[float] = [],
-        temporal_downsample: List[bool] = [False, True, True],
+        temperal_downsample: List[bool] = [False, True, True],
         dropout: float = 0.0,
         latents_mean: List[float] = [
             -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
@@ -545,31 +602,59 @@ def __init__(
         ],
     ):
         self.z_dim = z_dim
-        self.temporal_downsample = temporal_downsample
-        self.temporal_upsample = temporal_downsample[::-1]
-
-        self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1)
-        self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1, rngs=rngs)
-        self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1, rngs=rngs)
+        self.temperal_downsample = temperal_downsample
+        self.temporal_upsample = temperal_downsample[::-1]

-        self.decoder = WanDecoder3d(
-            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout
+        self.encoder = WanEncoder3d(
+            rngs=rngs,
+            dim=base_dim,
+            z_dim=z_dim * 2,
+            dim_mult=dim_mult,
+            num_res_blocks=num_res_blocks,
+            attn_scales=attn_scales,
+            temperal_downsample=temperal_downsample,
+            dropout=dropout,
+        )
+        self.quant_conv = WanCausalConv3d(
+            rngs=rngs,
+            in_channels=z_dim * 2,
+            out_channels=z_dim * 2,
+            kernel_size=1
+        )
+        self.post_quant_conv = WanCausalConv3d(
+            rngs=rngs,
+            in_channels=z_dim,
+            out_channels=z_dim,
+            kernel_size=1,
         )
+
+        # self.decoder = WanDecoder3d(
+        #     rngs=rngs,
+        #     dim=base_dim,
+        #     z_dim=z_dim,
+        #     dim_mult=dim_mult,
+        #     num_res_blocks=num_res_blocks,
+        #     attn_scales=attn_scales,
+        #     temperal_upsample=self.temporal_upsample,
+        #     dropout=dropout
+        # )
         self.clear_cache()

     def clear_cache(self):
         """ Resets cache dictionaries and indices"""
         def _count_conv3d(module):
             count = 0
-            node_types = nnx.graph.iter_graph(module, nnx.Module)
-            for node in node_types:
-                if isinstance(node.value, WanCausalConv3d):
+            node_types = nnx.graph.iter_graph([module])
+            for path, value in node_types:
+                #breakpoint()
+                if isinstance(value, WanCausalConv3d):
+                    print("value: ", value)
                     count += 1
             return count

-        self._conv_num = _count_conv3d(self.decoder)
-        self._conv_idx = [0]
-        self._feat_map = [None] * self._conv_num
+        # self._conv_num = _count_conv3d(self.decoder)
+        # self._conv_idx = [0]
+        # self._feat_map = [None] * self._conv_num
         # cache encode
         self._enc_conv_num = _count_conv3d(self.encoder)
         self._enc_conv_idx = [0]
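A small sketch of the cache bookkeeping that clear_cache sets up above. It assumes, based on how _encode passes these values further down, that each WanCausalConv3d owns one feat_cache slot and advances the shared feat_idx cursor once per call; that contract lives inside WanCausalConv3d and is not shown in this diff.

# hypothetical illustration; `encoder` and `_count_conv3d` come from this file
enc_conv_num = _count_conv3d(encoder)        # one cache slot per WanCausalConv3d
enc_feat_map = [None] * enc_conv_num         # per-conv cache of trailing frames
enc_conv_idx = [0]                           # mutable cursor shared by all convs
out = encoder(frames, feat_cache=enc_feat_map, feat_idx=enc_conv_idx)
# assumed invariant: every conv touched exactly one slot, so enc_conv_idx[0] == enc_conv_num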
@@ -581,7 +666,7 @@ def _encode(self, x: jax.Array):
         x = jnp.transpose(x, (0, 2, 3, 4, 1))
         assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"

-        self.clear_cache()
+        # self.clear_cache()

         t = x.shape[1]
         iter_ = 1 + (t - 1) // 4
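A worked sketch of the temporal chunking that iter_ drives above: the first frame is encoded on its own and the remaining frames in groups of four, matching the 4x temporal compression. The slice bounds of the else branch fall outside the displayed hunks, so the chunk boundaries below are an assumption based on that pattern.

t = 17                                   # example frame count
iter_ = 1 + (t - 1) // 4                 # -> 5 encoder calls
chunks = [(0, 1)] + [(1 + 4 * (i - 1), 1 + 4 * i) for i in range(1, iter_)]
# chunks == [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)], covering all 17 frames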
@@ -590,7 +675,7 @@ def _encode(self, x: jax.Array):
                 out = self.encoder(
                     x[:, :1, :, :, :],
                     feat_cache=self._enc_feat_map,
-                    feat_ids=self._enc_conv_idx
+                    feat_idx=self._enc_conv_idx
                 )
             else:
                 out_ = self.encoder(
@@ -600,11 +685,12 @@ def _encode(self, x: jax.Array):
                 )
                 out = jnp.concatenate([out, out_], axis=1)

-        enc = self.quant_conv(out)
-        mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
-        enc = jnp.concatenate([mu, logvar], dim=1)
-        self.clear_cache()
-        return enc
+        # enc = self.quant_conv(out)
+        # mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
+        # enc = jnp.concatenate([mu, logvar], dim=1)
+        # self.clear_cache()
+        # return enc
+        return x

     def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
         """ Encode video into latent distribution."""