@@ -168,7 +168,6 @@ def __call__(self, x: Array) -> Array:
168168 dimension_numbers = ('NLC' , 'LIO' , 'NLC' ),
169169 feature_group_count = num_channels ,
170170 )
171- jax .debug .print ("UpSample1d after conv - min: {min}, max: {max}" , min = x_upsampled .min (), max = x_upsampled .max ())
172171
173172 x_upsampled = x_upsampled * self .ratio
174173 return x_upsampled [:, self .pad_left : - self .pad_right , :]
@@ -215,8 +214,6 @@ def __call__(self, hidden_states: Array) -> Array:
215214 alpha = jnp .expand_dims (alpha , axis = 0 )
216215 amplitude = jnp .expand_dims (amplitude , axis = 0 )
217216
218- jax .debug .print ("SnakeBeta alpha - min: {min}, max: {max}" , min = alpha .min (), max = alpha .max ())
219- jax .debug .print ("SnakeBeta amplitude - min: {min}, max: {max}" , min = amplitude .min (), max = amplitude .max ())
220217 hidden_states = hidden_states + (1.0 / (amplitude + self .eps )) * jnp .sin (hidden_states * alpha ) ** 2
221218 return hidden_states
222219
@@ -233,11 +230,8 @@ def __init__(
233230 self .downsample = DownSample1d (ratio = ratio , kernel_size = kernel_size )
234231
235232 def __call__ (self , x : Array ) -> Array :
236- jax .debug .print ("AntiAliasAct1d input - min: {min}, max: {max}" , min = x .min (), max = x .max ())
237233 x = self .upsample (x )
238- jax .debug .print ("AntiAliasAct1d after upsample - min: {min}, max: {max}" , min = x .min (), max = x .max ())
239234 x = self .act (x )
240- jax .debug .print ("AntiAliasAct1d after act - min: {min}, max: {max}" , min = x .min (), max = x .max ())
241235 x = self .downsample (x )
242236 return x
243237
0 commit comments