@@ -94,6 +94,7 @@ def __call__(self, x: Array) -> Array:
9494 x = jnp .pad (x , ((0 , 0 ), (self .pad_left , self .pad_right ), (0 , 0 )), mode = 'edge' )
9595
9696 filter_expanded = jnp .repeat (self .filter , num_channels , axis = 2 )
97+ filter_expanded = filter_expanded .astype (x .dtype )
9798
9899 x_filtered = jax .lax .conv_general_dilated (
99100 x ,
@@ -149,6 +150,7 @@ def __call__(self, x: Array) -> Array:
149150 x = jnp .pad (x , ((0 , 0 ), (self .pad , self .pad ), (0 , 0 )), mode = 'edge' )
150151
151152 filter_expanded = jnp .repeat (self .filter , num_channels , axis = 2 )
153+ filter_expanded = filter_expanded .astype (x .dtype )
152154
153155 x_upsampled = jax .lax .conv_general_dilated (
154156 x ,
@@ -486,6 +488,7 @@ def __call__(self, waveform: Array) -> tuple[Array, Array]:
486488
487489 left_pad = max (0 , self .window_length - self .hop_length )
488490 waveform = jnp .pad (waveform , ((0 , 0 ), (left_pad , 0 ), (0 , 0 )))
491+ waveform = waveform .astype (self .forward_basis .value .dtype )
489492
490493 spec = jax .lax .conv_general_dilated (
491494 waveform ,
0 commit comments