@@ -33,21 +33,35 @@ def kaiser_window(n: int, beta: float) -> Array:
3333 alpha = (n - 1 ) / 2.0
3434 time = jnp .arange (n )
3535 term = beta * jnp .sqrt (1 - ((time - alpha ) / alpha ) ** 2 )
36- return jss .i0 (term ) / jss .i0 (beta )
36+ jax .debug .print ("kaiser_window term - min: {min}, max: {max}" , min = term .min (), max = term .max ())
37+
38+ i0_term = jss .i0 (term )
39+ i0_beta = jss .i0 (beta )
40+ jax .debug .print ("kaiser_window i0_term - min: {min}, max: {max}" , min = i0_term .min (), max = i0_term .max ())
41+ jax .debug .print ("kaiser_window i0_beta: {val}" , val = i0_beta )
42+
43+ res = i0_term / i0_beta
44+ return res
3745
3846def kaiser_sinc_filter1d (cutoff : float , half_width : float , kernel_size : int ) -> Array :
3947 """Creates a Kaiser sinc kernel for low-pass filtering."""
4048 delta_f = 4 * half_width
4149 half_size = kernel_size // 2
4250 amplitude = 2.285 * (half_size - 1 ) * math .pi * delta_f + 7.95
51+
52+ print (f"kaiser_sinc_filter1d amplitude: { amplitude } " )
53+
4354 if amplitude > 50.0 :
4455 beta = 0.1102 * (amplitude - 8.7 )
4556 elif amplitude >= 21.0 :
4657 beta = 0.5842 * (amplitude - 21 ) ** 0.4 + 0.07886 * (amplitude - 21.0 )
4758 else :
4859 beta = 0.0
4960
61+ print (f"kaiser_sinc_filter1d beta: { beta } " )
62+
5063 window = kaiser_window (kernel_size , beta )
64+ jax .debug .print ("kaiser_sinc_filter1d window - min: {min}, max: {max}" , min = window .min (), max = window .max ())
5165
5266 even = kernel_size % 2 == 0
5367 time = jnp .arange (- half_size , half_size ) + 0.5 if even else jnp .arange (kernel_size ) - half_size
@@ -61,7 +75,10 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
6175 jnp .ones_like (time ),
6276 jnp .sin (math .pi * time ) / math .pi / time ,
6377 )
78+ jax .debug .print ("kaiser_sinc_filter1d sinc - min: {min}, max: {max}" , min = sinc .min (), max = sinc .max ())
79+
6480 filter = 2 * cutoff * window * sinc
81+ jax .debug .print ("kaiser_sinc_filter1d before norm - min: {min}, max: {max}, sum: {sum}" , min = filter .min (), max = filter .max (), sum = filter .sum ())
6582 filter = filter / filter .sum ()
6683 return filter
6784
@@ -152,25 +169,17 @@ def __call__(self, x: Array) -> Array:
152169 num_channels = x .shape [- 1 ]
153170 batch , length , channels = x .shape
154171
155- jax .debug .print ("UpSample1d input - min: {min}, max: {max}" , min = x .min (), max = x .max ())
156-
157172 # Interleave zeros (manual upsampling)
158173 x_expanded = jnp .zeros ((batch , length * self .ratio , channels ), dtype = x .dtype )
159174 x_expanded = x_expanded .at [:, ::self .ratio , :].set (x )
160175
161- jax .debug .print ("UpSample1d after interleave - min: {min}, max: {max}" , min = x_expanded .min (), max = x_expanded .max ())
162-
163176 # Pad the expanded signal
164177 pad_len = self .pad * self .ratio
165178 x_padded = jnp .pad (x_expanded , ((0 , 0 ), (pad_len , pad_len ), (0 , 0 )), mode = 'edge' )
166179
167- jax .debug .print ("UpSample1d after pad - min: {min}, max: {max}" , min = x_padded .min (), max = x_padded .max ())
168-
169180 filter_expanded = jnp .repeat (self .filter , num_channels , axis = 2 )
170181 filter_expanded = filter_expanded .astype (x .dtype )
171182
172- jax .debug .print ("UpSample1d filter applied - min: {min}, max: {max}" , min = filter_expanded .min (), max = filter_expanded .max ())
173-
174183 x_upsampled = jax .lax .conv_general_dilated (
175184 x_padded ,
176185 filter_expanded ,
@@ -179,7 +188,6 @@ def __call__(self, x: Array) -> Array:
179188 dimension_numbers = ('NLC' , 'LIO' , 'NLC' ),
180189 feature_group_count = num_channels ,
181190 )
182- jax .debug .print ("UpSample1d after conv - min: {min}, max: {max}" , min = x_upsampled .min (), max = x_upsampled .max ())
183191
184192 x_upsampled = x_upsampled * self .ratio
185193 return x_upsampled [:, self .pad_left : - self .pad_right , :]
@@ -242,13 +250,9 @@ def __init__(
242250 self .downsample = DownSample1d (ratio = ratio , kernel_size = kernel_size )
243251
244252 def __call__ (self , x : Array ) -> Array :
245- jax .debug .print ("AntiAliasAct1d input - min: {min}, max: {max}" , min = x .min (), max = x .max ())
246253 x = self .upsample (x )
247- jax .debug .print ("AntiAliasAct1d after upsample - min: {min}, max: {max}" , min = x .min (), max = x .max ())
248254 x = self .act (x )
249- jax .debug .print ("AntiAliasAct1d after act - min: {min}, max: {max}" , min = x .min (), max = x .max ())
250255 x = self .downsample (x )
251- jax .debug .print ("AntiAliasAct1d after downsample - min: {min}, max: {max}" , min = x .min (), max = x .max ())
252256 return x
253257
254258
0 commit comments