Skip to content

Commit 2716aaf

Browse files
committed
Fix dtype mismatch in vocoder convolutions
1 parent f38ce62 commit 2716aaf

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __call__(self, x: Array) -> Array:
 94  94    x = jnp.pad(x, ((0, 0), (self.pad_left, self.pad_right), (0, 0)), mode='edge')
 95  95
 96  96    filter_expanded = jnp.repeat(self.filter, num_channels, axis=2)
     97 +  filter_expanded = filter_expanded.astype(x.dtype)
 97  98
 98  99    x_filtered = jax.lax.conv_general_dilated(
 99 100    x,
@@ -149,6 +150,7 @@ def __call__(self, x: Array) -> Array:
149 150    x = jnp.pad(x, ((0, 0), (self.pad, self.pad), (0, 0)), mode='edge')
150 151
151 152    filter_expanded = jnp.repeat(self.filter, num_channels, axis=2)
    153 +  filter_expanded = filter_expanded.astype(x.dtype)
152 154
153 155    x_upsampled = jax.lax.conv_general_dilated(
154 156    x,
@@ -486,6 +488,7 @@ def __call__(self, waveform: Array) -> tuple[Array, Array]:
486 488
487 489    left_pad = max(0, self.window_length - self.hop_length)
488 490    waveform = jnp.pad(waveform, ((0, 0), (left_pad, 0), (0, 0)))
    491 +  waveform = waveform.astype(self.forward_basis.value.dtype)
489 492
490 493    spec = jax.lax.conv_general_dilated(
491 494    waveform,

0 commit comments

Comments
 (0)