@@ -147,17 +147,24 @@ def __init__(
147147
148148 def __call__ (self , x : Array ) -> Array :
149149 num_channels = x .shape [- 1 ]
150- x = jnp .pad (x , ((0 , 0 ), (self .pad , self .pad ), (0 , 0 )), mode = 'edge' )
150+ batch , length , channels = x .shape
151+
152+ # Interleave zeros (manual upsampling)
153+ x_expanded = jnp .zeros ((batch , length * self .ratio , channels ), dtype = x .dtype )
154+ x_expanded = x_expanded .at [:, ::self .ratio , :].set (x )
155+
156+ # Pad the expanded signal
157+ pad_len = self .pad * self .ratio
158+ x_padded = jnp .pad (x_expanded , ((0 , 0 ), (pad_len , pad_len ), (0 , 0 )), mode = 'edge' )
151159
152160 filter_expanded = jnp .repeat (self .filter , num_channels , axis = 2 )
153161 filter_expanded = filter_expanded .astype (x .dtype )
154162
155163 x_upsampled = jax .lax .conv_general_dilated (
156- x ,
164+ x_padded ,
157165 filter_expanded ,
158166 window_strides = (1 ,),
159167 padding = ((0 , 0 ),),
160- lhs_dilation = (self .ratio ,),
161168 dimension_numbers = ('NLC' , 'LIO' , 'NLC' ),
162169 feature_group_count = num_channels ,
163170 )
0 commit comments