@@ -632,32 +632,53 @@ def __init__(
632632 )
633633
def __call__(self, mel_spec: Array) -> Array:
    """Run the base vocoder and add the bandwidth-extension residual.

    The base vocoder output is padded to a multiple of ``self.hop_length``,
    converted to a log-mel representation by ``self.mel_stft``, and fed to
    ``self.bwe_generator`` to produce a residual. That residual is added to
    ``self.resampler(x)`` and the sum is clipped to ``[-1, 1]``.

    Intermediate statistics are emitted through ``logging`` at DEBUG level
    only. With the default logging configuration nothing is printed and no
    extra ``min``/``max`` reductions are computed.

    Args:
        mel_spec: Input passed unchanged to ``self.vocoder``.

    Returns:
        The clipped waveform, trimmed along the sample axis to
        ``num_samples * output_sampling_rate // input_sampling_rate`` and
        transposed with ``(0, 2, 1)`` so the sample axis is last.
    """
    import logging

    log = logging.getLogger(__name__)
    debug_enabled = log.isEnabledFor(logging.DEBUG)

    def trace(label, arr):
        # min()/max() are extra reductions whose results must be
        # materialized to be formatted; only pay for them when DEBUG
        # logging is actually enabled.
        if debug_enabled:
            log.debug(
                "%s - shape: %s, min: %s, max: %s",
                label, arr.shape, arr.min(), arr.max(),
            )

    trace("Input mel_spec", mel_spec)

    x = self.vocoder(mel_spec)
    trace("Base vocoder output (x)", x)

    # After this transpose the axes are unpacked as
    # (batch_size, num_samples, num_channels).
    x = jnp.transpose(x, (0, 2, 1))
    batch_size, num_samples, num_channels = x.shape

    # Zero-pad the sample axis up to the next multiple of hop_length.
    remainder = num_samples % self.hop_length
    if remainder != 0:
        x = jnp.pad(x, ((0, 0), (0, self.hop_length - remainder), (0, 0)))

    # Fold channels into the leading axis so mel_stft sees one channel
    # per row: (batch * channels, samples, 1).
    x_flattened = x.transpose(0, 2, 1).reshape(-1, x.shape[1], 1)

    log_mel, _, _, _ = self.mel_stft(x_flattened)
    trace("MelSTFT output (log_mel)", log_mel)

    # Split the leading axis back into (batch, channels).
    log_mel = log_mel.reshape(batch_size, num_channels, -1, log_mel.shape[-1])

    residual = self.bwe_generator(log_mel, time_last=False)
    trace("BWE generator output (residual)", residual)

    skip = self.resampler(x)
    trace("Resampler output (skip)", skip)

    residual = jnp.transpose(residual, (0, 2, 1))

    # Make residual and skip agree along axis 1: extend a short residual by
    # repeating its last value, or drop the excess of a long one.
    if residual.shape[1] < skip.shape[1]:
        residual = jnp.pad(
            residual,
            ((0, 0), (0, skip.shape[1] - residual.shape[1]), (0, 0)),
            mode='edge',
        )
    elif residual.shape[1] > skip.shape[1]:
        residual = residual[:, :skip.shape[1], :]

    raw_waveform = residual + skip
    trace("Raw waveform (residual + skip)", raw_waveform)

    waveform = jnp.clip(raw_waveform, -1, 1)

    # Trim the samples that only exist because of the hop-length padding.
    output_samples = num_samples * self.output_sampling_rate // self.input_sampling_rate
    waveform = waveform[:, :output_samples, :]
    waveform = jnp.transpose(waveform, (0, 2, 1))
    trace("Final waveform", waveform)

    return waveform
0 commit comments