Skip to content

Commit 066eacb

Browse files
committed
Deep instrumentation inside UpSample1d
1 parent df5b8a5 commit 066eacb

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,25 @@ def __call__(self, x: Array) -> Array:
152152
num_channels = x.shape[-1]
153153
batch, length, channels = x.shape
154154

155+
jax.debug.print("UpSample1d input - min: {min}, max: {max}", min=x.min(), max=x.max())
156+
155157
# Interleave zeros (manual upsampling)
156158
x_expanded = jnp.zeros((batch, length * self.ratio, channels), dtype=x.dtype)
157159
x_expanded = x_expanded.at[:, ::self.ratio, :].set(x)
158160

161+
jax.debug.print("UpSample1d after interleave - min: {min}, max: {max}", min=x_expanded.min(), max=x_expanded.max())
162+
159163
# Pad the expanded signal
160164
pad_len = self.pad * self.ratio
161165
x_padded = jnp.pad(x_expanded, ((0, 0), (pad_len, pad_len), (0, 0)), mode='edge')
162166

167+
jax.debug.print("UpSample1d after pad - min: {min}, max: {max}", min=x_padded.min(), max=x_padded.max())
168+
163169
filter_expanded = jnp.repeat(self.filter, num_channels, axis=2)
164170
filter_expanded = filter_expanded.astype(x.dtype)
165171

172+
jax.debug.print("UpSample1d filter applied - min: {min}, max: {max}", min=filter_expanded.min(), max=filter_expanded.max())
173+
166174
x_upsampled = jax.lax.conv_general_dilated(
167175
x_padded,
168176
filter_expanded,

0 commit comments

Comments
 (0)