@@ -152,17 +152,25 @@ def __call__(self, x: Array) -> Array:
152152 num_channels = x .shape [- 1 ]
153153 batch , length , channels = x .shape
154154
155+ jax .debug .print ("UpSample1d input - min: {min}, max: {max}" , min = x .min (), max = x .max ())
156+
155157 # Interleave zeros (manual upsampling)
156158 x_expanded = jnp .zeros ((batch , length * self .ratio , channels ), dtype = x .dtype )
157159 x_expanded = x_expanded .at [:, ::self .ratio , :].set (x )
158160
161+ jax .debug .print ("UpSample1d after interleave - min: {min}, max: {max}" , min = x_expanded .min (), max = x_expanded .max ())
162+
159163 # Pad the expanded signal
160164 pad_len = self .pad * self .ratio
161165 x_padded = jnp .pad (x_expanded , ((0 , 0 ), (pad_len , pad_len ), (0 , 0 )), mode = 'edge' )
162166
167+ jax .debug .print ("UpSample1d after pad - min: {min}, max: {max}" , min = x_padded .min (), max = x_padded .max ())
168+
163169 filter_expanded = jnp .repeat (self .filter , num_channels , axis = 2 )
164170 filter_expanded = filter_expanded .astype (x .dtype )
165171
172+ jax .debug .print ("UpSample1d filter applied - min: {min}, max: {max}" , min = filter_expanded .min (), max = filter_expanded .max ())
173+
166174 x_upsampled = jax .lax .conv_general_dilated (
167175 x_padded ,
168176 filter_expanded ,
0 commit comments