Skip to content

Commit df5b8a5

Browse files
committed
Comprehensive trace for AntiAliasAct1d sub-components
1 parent b5ae172 commit df5b8a5

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(
8686
cutoff = 0.5 / ratio
8787
half_width = 0.6 / ratio
8888
low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size)
89+
print(f"DownSample1d filter - min: {low_pass_filter.min()}, max: {low_pass_filter.max()}")
8990
self.filter = jnp.expand_dims(low_pass_filter, axis=(1, 2))
9091

9192
def __call__(self, x: Array) -> Array:
@@ -104,6 +105,7 @@ def __call__(self, x: Array) -> Array:
104105
dimension_numbers=('NLC', 'LIO', 'NLC'),
105106
feature_group_count=num_channels,
106107
)
108+
jax.debug.print("DownSample1d after conv - min: {min}, max: {max}", min=x_filtered.min(), max=x_filtered.max())
107109
return x_filtered
108110

109111

@@ -143,6 +145,7 @@ def __init__(
143145
half_width=0.6 / ratio,
144146
kernel_size=self.kernel_size,
145147
)
148+
print(f"UpSample1d filter - min: {sinc_filter.min()}, max: {sinc_filter.max()}")
146149
self.filter = sinc_filter.reshape(-1, 1, 1)
147150

148151
def __call__(self, x: Array) -> Array:
@@ -168,6 +171,7 @@ def __call__(self, x: Array) -> Array:
168171
dimension_numbers=('NLC', 'LIO', 'NLC'),
169172
feature_group_count=num_channels,
170173
)
174+
jax.debug.print("UpSample1d after conv - min: {min}, max: {max}", min=x_upsampled.min(), max=x_upsampled.max())
171175

172176
x_upsampled = x_upsampled * self.ratio
173177
return x_upsampled[:, self.pad_left : -self.pad_right, :]
@@ -230,9 +234,13 @@ def __init__(
230234
self.downsample = DownSample1d(ratio=ratio, kernel_size=kernel_size)
231235

232236
def __call__(self, x: Array) -> Array:
237+
jax.debug.print("AntiAliasAct1d input - min: {min}, max: {max}", min=x.min(), max=x.max())
233238
x = self.upsample(x)
239+
jax.debug.print("AntiAliasAct1d after upsample - min: {min}, max: {max}", min=x.min(), max=x.max())
234240
x = self.act(x)
241+
jax.debug.print("AntiAliasAct1d after act - min: {min}, max: {max}", min=x.min(), max=x.max())
235242
x = self.downsample(x)
243+
jax.debug.print("AntiAliasAct1d after downsample - min: {min}, max: {max}", min=x.min(), max=x.max())
236244
return x
237245

238246

0 commit comments

Comments
 (0)