@@ -161,6 +161,7 @@ def __call__(self, x: Array) -> Array:
161161 dimension_numbers = ('NLC' , 'LIO' , 'NLC' ),
162162 feature_group_count = num_channels ,
163163 )
164+ jax .debug .print ("UpSample1d after conv - min: {min}, max: {max}" , min = x_upsampled .min (), max = x_upsampled .max ())
164165
165166 x_upsampled = x_upsampled * self .ratio
166167 return x_upsampled [:, self .pad_left : - self .pad_right , :]
@@ -207,6 +208,8 @@ def __call__(self, hidden_states: Array) -> Array:
207208 alpha = jnp .expand_dims (alpha , axis = 0 )
208209 amplitude = jnp .expand_dims (amplitude , axis = 0 )
209210
211+ jax .debug .print ("SnakeBeta alpha - min: {min}, max: {max}" , min = alpha .min (), max = alpha .max ())
212+ jax .debug .print ("SnakeBeta amplitude - min: {min}, max: {max}" , min = amplitude .min (), max = amplitude .max ())
210213 hidden_states = hidden_states + (1.0 / (amplitude + self .eps )) * jnp .sin (hidden_states * alpha ) ** 2
211214 return hidden_states
212215
@@ -223,8 +226,11 @@ def __init__(
223226 self .downsample = DownSample1d (ratio = ratio , kernel_size = kernel_size )
224227
225228 def __call__ (self , x : Array ) -> Array :
229+ jax .debug .print ("AntiAliasAct1d input - min: {min}, max: {max}" , min = x .min (), max = x .max ())
226230 x = self .upsample (x )
231+ jax .debug .print ("AntiAliasAct1d after upsample - min: {min}, max: {max}" , min = x .min (), max = x .max ())
227232 x = self .act (x )
233+ jax .debug .print ("AntiAliasAct1d after act - min: {min}, max: {max}" , min = x .min (), max = x .max ())
228234 x = self .downsample (x )
229235 return x
230236
@@ -443,21 +449,17 @@ def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
443449
444450 if not time_last :
445451 hidden_states = jnp .transpose (hidden_states , (0 , 1 , 3 , 2 ))
446- print (f"Transposed hidden_states - shape: { hidden_states .shape } " )
447452
448453 batch , channels , mel_bins , time = hidden_states .shape
449454 hidden_states = hidden_states .reshape (batch , channels * mel_bins , time )
450455 hidden_states = jnp .transpose (hidden_states , (0 , 2 , 1 ))
451- print (f"Prepared hidden_states for conv_in - shape: { hidden_states .shape } " )
452456
453457 hidden_states = self .conv_in (hidden_states )
454- print (f"After conv_in - shape: { hidden_states .shape } , min: { hidden_states .min ()} , max: { hidden_states .max ()} " )
455458
456459 for i in range (self .num_upsample_layers ):
457460 if self .act_fn == "leaky_relu" :
458461 hidden_states = jax .nn .leaky_relu (hidden_states , negative_slope = self .negative_slope )
459462 hidden_states = self .upsamplers [i ](hidden_states )
460- print (f"After upsampler { i } - shape: { hidden_states .shape } , min: { hidden_states .min ()} , max: { hidden_states .max ()} " )
461463
462464 start = i * self .resnets_per_upsample
463465 end = (i + 1 ) * self .resnets_per_upsample
@@ -467,7 +469,6 @@ def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
467469 res_sum = res_sum + self .resnets [j ](hidden_states )
468470
469471 hidden_states = res_sum / self .resnets_per_upsample
470- print (f"After resnets level { i } - shape: { hidden_states .shape } , min: { hidden_states .min ()} , max: { hidden_states .max ()} " )
471472
472473 hidden_states = self .act_out (hidden_states )
473474 jax .debug .print ("After act_out - min: {min}, max: {max}" , min = hidden_states .min (), max = hidden_states .max ())