Skip to content

Commit 98195bd

Browse files
committed
vocoder zero fix
1 parent cddbcf1 commit 98195bd

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -147,17 +147,24 @@ def __init__(
147147

148148
def __call__(self, x: Array) -> Array:
149149
num_channels = x.shape[-1]
150-
x = jnp.pad(x, ((0, 0), (self.pad, self.pad), (0, 0)), mode='edge')
150+
batch, length, channels = x.shape
151+
152+
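# Interleave zeros (manual upsampling): allocate a zero buffer of length
# `length * self.ratio` and write the input samples at every `self.ratio`-th position.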
153+
x_expanded = jnp.zeros((batch, length * self.ratio, channels), dtype=x.dtype)
154+
x_expanded = x_expanded.at[:, ::self.ratio, :].set(x)
155+
156+
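# Pad the expanded signal by `self.pad * self.ratio` samples on each side of the
# length axis, using edge mode (boundary values are repeated).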
157+
pad_len = self.pad * self.ratio
158+
x_padded = jnp.pad(x_expanded, ((0, 0), (pad_len, pad_len), (0, 0)), mode='edge')
151159

152160
filter_expanded = jnp.repeat(self.filter, num_channels, axis=2)
153161
filter_expanded = filter_expanded.astype(x.dtype)
154162

155163
x_upsampled = jax.lax.conv_general_dilated(
156-
x,
164+
x_padded,
157165
filter_expanded,
158166
window_strides=(1,),
159167
padding=((0, 0),),
160-
lhs_dilation=(self.ratio,),
161168
dimension_numbers=('NLC', 'LIO', 'NLC'),
162169
feature_group_count=num_channels,
163170
)

0 commit comments

Comments
 (0)