Skip to content

Commit bae29eb

Browse files
committed
vocoder bwe changes
1 parent dbe9ae4 commit bae29eb

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/models/ltx2/vocoder_bwe_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def __call__(self, y: Array) -> Tuple[Array, Array]:
489489
w_transposed = jnp.transpose(self.forward_basis.value, (2, 1, 0)) # (O, I, K)
490490

491491
spec = jax.lax.conv_general_dilated(
492-
lhs=y_transposed,
492+
lhs=y_transposed.astype(w_transposed.dtype),
493493
rhs=w_transposed,
494494
window_strides=(self.hop_length,),
495495
padding="VALID",

0 commit comments

Comments
 (0)