Skip to content

Commit 728a6d1

Browse files
committed
Add shape matching workaround for BWE residual and skip
1 parent 2716aaf commit 728a6d1

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,12 @@ def __call__(self, mel_spec: Array) -> Array:
648648

649649
skip = self.resampler(x)
650650
residual = jnp.transpose(residual, (0, 2, 1))
651+
652+
if residual.shape[1] < skip.shape[1]:
653+
residual = jnp.pad(residual, ((0, 0), (0, skip.shape[1] - residual.shape[1]), (0, 0)), mode='edge')
654+
elif residual.shape[1] > skip.shape[1]:
655+
residual = residual[:, :skip.shape[1], :]
656+
651657
waveform = jnp.clip(residual + skip, -1, 1)
652658

653659
output_samples = num_samples * self.output_sampling_rate // self.input_sampling_rate

0 commit comments

Comments
 (0)