5757flax .config .update ("flax_always_shard_variable" , False )
5858
5959
60- class TorchWanRMS_norm (nn .Module ):
60+ class TorchWanRMS_norm (torch . nn .Module ):
6161 r"""
6262 A custom RMS normalization layer.
6363
@@ -76,14 +76,14 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi
7676
7777 self .channel_first = channel_first
7878 self .scale = dim ** 0.5
79- self .gamma = nn .Parameter (torch .ones (shape ))
80- self .bias = nn .Parameter (torch .zeros (shape )) if bias else 0.0
79+ self .gamma = torch . nn .Parameter (torch .ones (shape ))
80+ self .bias = torch . nn .Parameter (torch .zeros (shape )) if bias else 0.0
8181
8282 def forward (self , x ):
8383 return F .normalize (x , dim = (1 if self .channel_first else - 1 )) * self .scale * self .gamma + self .bias
8484
8585
86- class TorchWanResample (nn .Module ):
86+ class TorchWanResample (torch . nn .Module ):
8787 r"""
8888 A custom resampling module for 2D and 3D data.
8989
@@ -104,18 +104,18 @@ def __init__(self, dim: int, mode: str) -> None:
104104
105105 # layers
106106 if mode == "upsample2d" :
107- self .resample = nn .Sequential (
108- WanUpsample (scale_factor = (2.0 , 2.0 ), mode = "nearest-exact " ), nn .Conv2d (dim , dim // 2 , 3 , padding = 1 )
107+ self .resample = torch . nn .Sequential (
108+ torch . nn . Upsample (scale_factor = (2.0 , 2.0 ), mode = "nearest" ), torch . nn .Conv2d (dim , dim // 2 , 3 , padding = 1 )
109109 )
110110 elif mode == "upsample3d" :
111111 raise Exception ("downsample3d not supported" )
112112
113113 elif mode == "downsample2d" :
114- self .resample = nn .Sequential (nn .ZeroPad2d ((0 , 1 , 0 , 1 )), nn .Conv2d (dim , dim , 3 , stride = (2 , 2 )))
114+ self .resample = torch . nn .Sequential (torch . nn .ZeroPad2d ((0 , 1 , 0 , 1 )), torch . nn .Conv2d (dim , dim , 3 , stride = (2 , 2 )))
115115 elif mode == "downsample3d" :
116116 raise Exception ("downsample3d not supported" )
117117 else :
118- self .resample = nn .Identity ()
118+ self .resample = torch . nn .Identity ()
119119
120120 def forward (self , x , feat_cache = None , feat_idx = [0 ]):
121121 b , c , t , h , w = x .size ()
@@ -218,7 +218,7 @@ def test_zero_padded_conv(self):
218218 dim = 96
219219 kernel_size = 3
220220 stride = (2 , 2 )
221- resample = nn .Sequential (nn .ZeroPad2d ((0 , 1 , 0 , 1 )), nn .Conv2d (dim , dim , kernel_size , stride = stride ))
221+ resample = torch . nn .Sequential (torch . nn .ZeroPad2d ((0 , 1 , 0 , 1 )), torch . nn .Conv2d (dim , dim , kernel_size , stride = stride ))
222222 input_shape = (1 , 96 , 480 , 720 )
223223 input = torch .ones (input_shape )
224224 output_torch = resample (input )
0 commit comments