We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4aa4158 commit dbe9ae4Copy full SHA for dbe9ae4
1 file changed
src/maxdiffusion/models/ltx2/vocoder_bwe_ltx2.py
@@ -218,12 +218,20 @@ def __call__(self, x: Array) -> Array:
218
w_transposed = jnp.transpose(self.filter_weight, (2, 1, 0)) # (C, 1, K)
219
w_flipped = w_transposed[..., ::-1]
220
221
+ # Manual dilation to match PyTorch ConvTranspose1d behavior
222
+ b, c, t = x_transposed.shape
223
+ x_dilated = jnp.zeros((b, c, (t - 1) * self.ratio + 1), dtype=x.dtype)
224
+ x_dilated = x_dilated.at[:, :, ::self.ratio].set(x_transposed)
225
+
226
+ # Pad with 2 * (K - 1) zeros on the right to match PyTorch output length
227
+ pad_len = 2 * (self.kernel_size - 1)
228
+ x_dilated = jnp.pad(x_dilated, ((0, 0), (0, 0), (0, pad_len)))
229
230
out = jax.lax.conv_general_dilated(
- lhs=x_transposed,
231
+ lhs=x_dilated,
232
rhs=w_flipped,
233
window_strides=(1,),
234
padding=((0, 0),),
- lhs_dilation=(self.ratio,),
235
feature_group_count=self.channels,
236
dimension_numbers=jax.lax.ConvDimensionNumbers(
237
lhs_spec=(0, 1, 2), # N, C, W
0 commit comments