@@ -100,14 +100,13 @@ def __init__(
100100 self .use_padding = use_padding
101101 self .padding_mode = padding_mode
102102
103- cutoff = 0.5 / ratio
104- half_width = 0.6 / ratio
105- low_pass_filter = kaiser_sinc_filter1d (cutoff , half_width , self .kernel_size )
106- print (f"DownSample1d filter - min: { low_pass_filter .min ()} , max: { low_pass_filter .max ()} " )
107- self .filter = jnp .expand_dims (low_pass_filter , axis = (1 , 2 ))
108-
109103 def __call__ (self , x : Array ) -> Array :
110104 num_channels = x .shape [- 1 ]
105+
106+ cutoff = 0.5 / self .ratio
107+ half_width = 0.6 / self .ratio
108+ low_pass_filter = kaiser_sinc_filter1d (cutoff , half_width , self .kernel_size )
109+ filter = jnp .expand_dims (low_pass_filter , axis = (1 , 2 ))
111110 if self .use_padding :
112111 x = jnp .pad (x , ((0 , 0 ), (self .pad_left , self .pad_right ), (0 , 0 )), mode = 'edge' )
113112
@@ -137,6 +136,8 @@ def __init__(
137136 self .ratio = ratio
138137 self .padding_mode = padding_mode
139138
139+ self .window_type = window_type
140+
140141 if window_type == "hann" :
141142 rolloff = 0.99
142143 lowpass_filter_width = 6
@@ -145,30 +146,33 @@ def __init__(
145146 self .pad = width
146147 self .pad_left = 2 * width * ratio
147148 self .pad_right = self .kernel_size - ratio
148-
149- time_axis = (jnp .arange (self .kernel_size ) / ratio - width ) * rolloff
150- time_clamped = jnp .clip (time_axis , - lowpass_filter_width , lowpass_filter_width )
151- window = jnp .cos (time_clamped * math .pi / lowpass_filter_width / 2 ) ** 2
152- sinc_filter = jnp .sinc (time_axis ) * window * rolloff / ratio
153- self .filter = sinc_filter .reshape (- 1 , 1 , 1 )
154149 else :
155150 self .kernel_size = int (6 * ratio // 2 ) * 2 if kernel_size is None else kernel_size
156151 self .pad = self .kernel_size // ratio - 1
157152 self .pad_left = self .pad * self .ratio + (self .kernel_size - self .ratio ) // 2
158153 self .pad_right = self .pad * self .ratio + (self .kernel_size - self .ratio + 1 ) // 2
159154
160- sinc_filter = kaiser_sinc_filter1d (
161- cutoff = 0.5 / ratio ,
162- half_width = 0.6 / ratio ,
163- kernel_size = self .kernel_size ,
164- )
165- print (f"UpSample1d filter - min: { sinc_filter .min ()} , max: { sinc_filter .max ()} " )
166- self .filter = sinc_filter .reshape (- 1 , 1 , 1 )
167-
168155 def __call__ (self , x : Array ) -> Array :
169156 num_channels = x .shape [- 1 ]
170157 batch , length , channels = x .shape
171158
159+ if self .window_type == "hann" :
160+ rolloff = 0.99
161+ lowpass_filter_width = 6
162+ width = math .ceil (lowpass_filter_width / rolloff )
163+ time_axis = (jnp .arange (self .kernel_size ) / self .ratio - width ) * rolloff
164+ time_clamped = jnp .clip (time_axis , - lowpass_filter_width , lowpass_filter_width )
165+ window = jnp .cos (time_clamped * math .pi / lowpass_filter_width / 2 ) ** 2
166+ sinc_filter = jnp .sinc (time_axis ) * window * rolloff / self .ratio
167+ filter = sinc_filter .reshape (- 1 , 1 , 1 )
168+ else :
169+ sinc_filter = kaiser_sinc_filter1d (
170+ cutoff = 0.5 / self .ratio ,
171+ half_width = 0.6 / self .ratio ,
172+ kernel_size = self .kernel_size ,
173+ )
174+ filter = sinc_filter .reshape (- 1 , 1 , 1 )
175+
172176 # Interleave zeros (manual upsampling)
173177 x_expanded = jnp .zeros ((batch , length * self .ratio , channels ), dtype = x .dtype )
174178 x_expanded = x_expanded .at [:, ::self .ratio , :].set (x )
@@ -177,7 +181,7 @@ def __call__(self, x: Array) -> Array:
177181 pad_len = self .pad * self .ratio
178182 x_padded = jnp .pad (x_expanded , ((0 , 0 ), (pad_len , pad_len ), (0 , 0 )), mode = 'edge' )
179183
180- filter_expanded = jnp .repeat (self . filter , num_channels , axis = 2 )
184+ filter_expanded = jnp .repeat (filter , num_channels , axis = 2 )
181185 filter_expanded = filter_expanded .astype (x .dtype )
182186
183187 x_upsampled = jax .lax .conv_general_dilated (
0 commit comments