Skip to content

Commit dbe9ae4

Browse files
committed
vocoder bwe changes
1 parent 4aa4158 commit dbe9ae4

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_bwe_ltx2.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,20 @@ def __call__(self, x: Array) -> Array:
218218
w_transposed = jnp.transpose(self.filter_weight, (2, 1, 0)) # (C, 1, K)
219219
w_flipped = w_transposed[..., ::-1]
220220

221+
# Zero-stuff the input by hand (ratio - 1 zeros between consecutive samples) to mirror the stride of PyTorch's ConvTranspose1d
222+
b, c, t = x_transposed.shape
223+
x_dilated = jnp.zeros((b, c, (t - 1) * self.ratio + 1), dtype=x.dtype)
224+
x_dilated = x_dilated.at[:, :, ::self.ratio].set(x_transposed)
225+
226+
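# Pad 2 * (K - 1) zeros on the right so the stride-1, unpadded conv below yields (t - 1) * ratio + K samples, i.e. PyTorch's ConvTranspose1d output length for padding=0 (assumes the filter has K = kernel_size taps).
# NOTE(review): all padding is on the right; a full transposed conv pads K - 1 on each side, so confirm the resulting K - 1 sample offset is intended or compensated downstream.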
227+
pad_len = 2 * (self.kernel_size - 1)
228+
x_dilated = jnp.pad(x_dilated, ((0, 0), (0, 0), (0, pad_len)))
229+
221230
out = jax.lax.conv_general_dilated(
222-
lhs=x_transposed,
231+
lhs=x_dilated,
223232
rhs=w_flipped,
224233
window_strides=(1,),
225234
padding=((0, 0),),
226-
lhs_dilation=(self.ratio,),
227235
feature_group_count=self.channels,
228236
dimension_numbers=jax.lax.ConvDimensionNumbers(
229237
lhs_spec=(0, 1, 2), # N, C, W

0 commit comments

Comments
 (0)