Skip to content

Commit 9541cc1

Browse files
committed
Add comprehensive debug prints to BWE vocoder
1 parent 728a6d1 commit 9541cc1

1 file changed

Lines changed: 22 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -632,32 +632,53 @@ def __init__(
632632
)
633633

634634
def __call__(self, mel_spec: Array) -> Array:
635+
print(f"=== BWE Vocoder Debug ===")
636+
print(f"Input mel_spec - shape: {mel_spec.shape}, min: {mel_spec.min()}, max: {mel_spec.max()}")
637+
635638
x = self.vocoder(mel_spec)
639+
print(f"Base vocoder output (x) - shape: {x.shape}, min: {x.min()}, max: {x.max()}")
640+
636641
x = jnp.transpose(x, (0, 2, 1))
637642
batch_size, num_samples, num_channels = x.shape
643+
print(f"Transposed x - shape: {x.shape}")
638644

639645
remainder = num_samples % self.hop_length
640646
if remainder != 0:
641647
x = jnp.pad(x, ((0, 0), (0, self.hop_length - remainder), (0, 0)))
648+
print(f"Padded x - shape: {x.shape}")
642649

643650
x_flattened = x.transpose(0, 2, 1).reshape(-1, x.shape[1], 1)
651+
print(f"x_flattened - shape: {x_flattened.shape}")
652+
644653
log_mel, _, _, _ = self.mel_stft(x_flattened)
654+
print(f"MelSTFT output (log_mel) before reshape - shape: {log_mel.shape}, min: {log_mel.min()}, max: {log_mel.max()}")
655+
645656
log_mel = log_mel.reshape(batch_size, num_channels, -1, log_mel.shape[-1])
657+
print(f"Reshaped log_mel - shape: {log_mel.shape}")
646658

647659
residual = self.bwe_generator(log_mel, time_last=False)
660+
print(f"BWE generator output (residual) - shape: {residual.shape}, min: {residual.min()}, max: {residual.max()}")
648661

649662
skip = self.resampler(x)
663+
print(f"Resampler output (skip) - shape: {skip.shape}, min: {skip.min()}, max: {skip.max()}")
664+
650665
residual = jnp.transpose(residual, (0, 2, 1))
651666

652667
if residual.shape[1] < skip.shape[1]:
653668
residual = jnp.pad(residual, ((0, 0), (0, skip.shape[1] - residual.shape[1]), (0, 0)), mode='edge')
654669
elif residual.shape[1] > skip.shape[1]:
655670
residual = residual[:, :skip.shape[1], :]
671+
print(f"Matched residual - shape: {residual.shape}")
656672

657-
waveform = jnp.clip(residual + skip, -1, 1)
673+
raw_waveform = residual + skip
674+
print(f"Raw waveform (residual + skip) - min: {raw_waveform.min()}, max: {raw_waveform.max()}")
675+
676+
waveform = jnp.clip(raw_waveform, -1, 1)
658677

659678
output_samples = num_samples * self.output_sampling_rate // self.input_sampling_rate
660679
waveform = waveform[:, :output_samples, :]
661680
waveform = jnp.transpose(waveform, (0, 2, 1))
681+
print(f"Final waveform - shape: {waveform.shape}, min: {waveform.min()}, max: {waveform.max()}")
682+
print(f"=========================")
662683

663684
return waveform

0 commit comments

Comments
 (0)