Skip to content

Commit 9b7cf8f

Browse files
committed
Remove verbose debugs, keep Kaiser instrumentation
1 parent 066eacb commit 9b7cf8f

1 file changed

Lines changed: 18 additions & 14 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,35 @@ def kaiser_window(n: int, beta: float) -> Array:
3333
alpha = (n - 1) / 2.0
3434
time = jnp.arange(n)
3535
term = beta * jnp.sqrt(1 - ((time - alpha) / alpha) ** 2)
36-
return jss.i0(term) / jss.i0(beta)
36+
jax.debug.print("kaiser_window term - min: {min}, max: {max}", min=term.min(), max=term.max())
37+
38+
i0_term = jss.i0(term)
39+
i0_beta = jss.i0(beta)
40+
jax.debug.print("kaiser_window i0_term - min: {min}, max: {max}", min=i0_term.min(), max=i0_term.max())
41+
jax.debug.print("kaiser_window i0_beta: {val}", val=i0_beta)
42+
43+
res = i0_term / i0_beta
44+
return res
3745

3846
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> Array:
3947
"""Creates a Kaiser sinc kernel for low-pass filtering."""
4048
delta_f = 4 * half_width
4149
half_size = kernel_size // 2
4250
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
51+
52+
print(f"kaiser_sinc_filter1d amplitude: {amplitude}")
53+
4354
if amplitude > 50.0:
4455
beta = 0.1102 * (amplitude - 8.7)
4556
elif amplitude >= 21.0:
4657
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
4758
else:
4859
beta = 0.0
4960

61+
print(f"kaiser_sinc_filter1d beta: {beta}")
62+
5063
window = kaiser_window(kernel_size, beta)
64+
jax.debug.print("kaiser_sinc_filter1d window - min: {min}, max: {max}", min=window.min(), max=window.max())
5165

5266
even = kernel_size % 2 == 0
5367
time = jnp.arange(-half_size, half_size) + 0.5 if even else jnp.arange(kernel_size) - half_size
@@ -61,7 +75,10 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
6175
jnp.ones_like(time),
6276
jnp.sin(math.pi * time) / math.pi / time,
6377
)
78+
jax.debug.print("kaiser_sinc_filter1d sinc - min: {min}, max: {max}", min=sinc.min(), max=sinc.max())
79+
6480
filter = 2 * cutoff * window * sinc
81+
jax.debug.print("kaiser_sinc_filter1d before norm - min: {min}, max: {max}, sum: {sum}", min=filter.min(), max=filter.max(), sum=filter.sum())
6582
filter = filter / filter.sum()
6683
return filter
6784

@@ -152,25 +169,17 @@ def __call__(self, x: Array) -> Array:
152169
num_channels = x.shape[-1]
153170
batch, length, channels = x.shape
154171

155-
jax.debug.print("UpSample1d input - min: {min}, max: {max}", min=x.min(), max=x.max())
156-
157172
# Interleave zeros (manual upsampling)
158173
x_expanded = jnp.zeros((batch, length * self.ratio, channels), dtype=x.dtype)
159174
x_expanded = x_expanded.at[:, ::self.ratio, :].set(x)
160175

161-
jax.debug.print("UpSample1d after interleave - min: {min}, max: {max}", min=x_expanded.min(), max=x_expanded.max())
162-
163176
# Pad the expanded signal
164177
pad_len = self.pad * self.ratio
165178
x_padded = jnp.pad(x_expanded, ((0, 0), (pad_len, pad_len), (0, 0)), mode='edge')
166179

167-
jax.debug.print("UpSample1d after pad - min: {min}, max: {max}", min=x_padded.min(), max=x_padded.max())
168-
169180
filter_expanded = jnp.repeat(self.filter, num_channels, axis=2)
170181
filter_expanded = filter_expanded.astype(x.dtype)
171182

172-
jax.debug.print("UpSample1d filter applied - min: {min}, max: {max}", min=filter_expanded.min(), max=filter_expanded.max())
173-
174183
x_upsampled = jax.lax.conv_general_dilated(
175184
x_padded,
176185
filter_expanded,
@@ -179,7 +188,6 @@ def __call__(self, x: Array) -> Array:
179188
dimension_numbers=('NLC', 'LIO', 'NLC'),
180189
feature_group_count=num_channels,
181190
)
182-
jax.debug.print("UpSample1d after conv - min: {min}, max: {max}", min=x_upsampled.min(), max=x_upsampled.max())
183191

184192
x_upsampled = x_upsampled * self.ratio
185193
return x_upsampled[:, self.pad_left : -self.pad_right, :]
@@ -242,13 +250,9 @@ def __init__(
242250
self.downsample = DownSample1d(ratio=ratio, kernel_size=kernel_size)
243251

244252
def __call__(self, x: Array) -> Array:
245-
jax.debug.print("AntiAliasAct1d input - min: {min}, max: {max}", min=x.min(), max=x.max())
246253
x = self.upsample(x)
247-
jax.debug.print("AntiAliasAct1d after upsample - min: {min}, max: {max}", min=x.min(), max=x.max())
248254
x = self.act(x)
249-
jax.debug.print("AntiAliasAct1d after act - min: {min}, max: {max}", min=x.min(), max=x.max())
250255
x = self.downsample(x)
251-
jax.debug.print("AntiAliasAct1d after downsample - min: {min}, max: {max}", min=x.min(), max=x.max())
252256
return x
253257

254258

0 commit comments

Comments
 (0)