@@ -86,6 +86,7 @@ def __init__(
8686 cutoff = 0.5 / ratio
8787 half_width = 0.6 / ratio
8888 low_pass_filter = kaiser_sinc_filter1d (cutoff , half_width , self .kernel_size )
89+ print (f"DownSample1d filter - min: { low_pass_filter .min ()} , max: { low_pass_filter .max ()} " )
8990 self .filter = jnp .expand_dims (low_pass_filter , axis = (1 , 2 ))
9091
9192 def __call__ (self , x : Array ) -> Array :
@@ -104,6 +105,7 @@ def __call__(self, x: Array) -> Array:
104105 dimension_numbers = ('NLC' , 'LIO' , 'NLC' ),
105106 feature_group_count = num_channels ,
106107 )
108+ jax .debug .print ("DownSample1d after conv - min: {min}, max: {max}" , min = x_filtered .min (), max = x_filtered .max ())
107109 return x_filtered
108110
109111
@@ -143,6 +145,7 @@ def __init__(
143145 half_width = 0.6 / ratio ,
144146 kernel_size = self .kernel_size ,
145147 )
148+ print (f"UpSample1d filter - min: { sinc_filter .min ()} , max: { sinc_filter .max ()} " )
146149 self .filter = sinc_filter .reshape (- 1 , 1 , 1 )
147150
148151 def __call__ (self , x : Array ) -> Array :
@@ -168,6 +171,7 @@ def __call__(self, x: Array) -> Array:
168171 dimension_numbers = ('NLC' , 'LIO' , 'NLC' ),
169172 feature_group_count = num_channels ,
170173 )
174+ jax .debug .print ("UpSample1d after conv - min: {min}, max: {max}" , min = x_upsampled .min (), max = x_upsampled .max ())
171175
172176 x_upsampled = x_upsampled * self .ratio
173177 return x_upsampled [:, self .pad_left : - self .pad_right , :]
@@ -230,9 +234,13 @@ def __init__(
230234 self .downsample = DownSample1d (ratio = ratio , kernel_size = kernel_size )
231235
232236 def __call__ (self , x : Array ) -> Array :
237+ jax .debug .print ("AntiAliasAct1d input - min: {min}, max: {max}" , min = x .min (), max = x .max ())
233238 x = self .upsample (x )
239+ jax .debug .print ("AntiAliasAct1d after upsample - min: {min}, max: {max}" , min = x .min (), max = x .max ())
234240 x = self .act (x )
241+ jax .debug .print ("AntiAliasAct1d after act - min: {min}, max: {max}" , min = x .min (), max = x .max ())
235242 x = self .downsample (x )
243+ jax .debug .print ("AntiAliasAct1d after downsample - min: {min}, max: {max}" , min = x .min (), max = x .max ())
236244 return x
237245
238246
0 commit comments