Skip to content

Commit cddbcf1

Browse files
committed
vocoder debug
1 parent bdcdeb1 commit cddbcf1

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -161,6 +161,7 @@ def __call__(self, x: Array) -> Array:
161161
dimension_numbers=('NLC', 'LIO', 'NLC'),
162162
feature_group_count=num_channels,
163163
)
164+
jax.debug.print("UpSample1d after conv - min: {min}, max: {max}", min=x_upsampled.min(), max=x_upsampled.max())
164165

165166
x_upsampled = x_upsampled * self.ratio
166167
return x_upsampled[:, self.pad_left : -self.pad_right, :]
@@ -207,6 +208,8 @@ def __call__(self, hidden_states: Array) -> Array:
207208
alpha = jnp.expand_dims(alpha, axis=0)
208209
amplitude = jnp.expand_dims(amplitude, axis=0)
209210

211+
jax.debug.print("SnakeBeta alpha - min: {min}, max: {max}", min=alpha.min(), max=alpha.max())
212+
jax.debug.print("SnakeBeta amplitude - min: {min}, max: {max}", min=amplitude.min(), max=amplitude.max())
210213
hidden_states = hidden_states + (1.0 / (amplitude + self.eps)) * jnp.sin(hidden_states * alpha) ** 2
211214
return hidden_states
212215

@@ -223,8 +226,11 @@ def __init__(
223226
self.downsample = DownSample1d(ratio=ratio, kernel_size=kernel_size)
224227

225228
def __call__(self, x: Array) -> Array:
229+
jax.debug.print("AntiAliasAct1d input - min: {min}, max: {max}", min=x.min(), max=x.max())
226230
x = self.upsample(x)
231+
jax.debug.print("AntiAliasAct1d after upsample - min: {min}, max: {max}", min=x.min(), max=x.max())
227232
x = self.act(x)
233+
jax.debug.print("AntiAliasAct1d after act - min: {min}, max: {max}", min=x.min(), max=x.max())
228234
x = self.downsample(x)
229235
return x
230236

@@ -443,21 +449,17 @@ def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
443449

444450
if not time_last:
445451
hidden_states = jnp.transpose(hidden_states, (0, 1, 3, 2))
446-
print(f"Transposed hidden_states - shape: {hidden_states.shape}")
447452

448453
batch, channels, mel_bins, time = hidden_states.shape
449454
hidden_states = hidden_states.reshape(batch, channels * mel_bins, time)
450455
hidden_states = jnp.transpose(hidden_states, (0, 2, 1))
451-
print(f"Prepared hidden_states for conv_in - shape: {hidden_states.shape}")
452456

453457
hidden_states = self.conv_in(hidden_states)
454-
print(f"After conv_in - shape: {hidden_states.shape}, min: {hidden_states.min()}, max: {hidden_states.max()}")
455458

456459
for i in range(self.num_upsample_layers):
457460
if self.act_fn == "leaky_relu":
458461
hidden_states = jax.nn.leaky_relu(hidden_states, negative_slope=self.negative_slope)
459462
hidden_states = self.upsamplers[i](hidden_states)
460-
print(f"After upsampler {i} - shape: {hidden_states.shape}, min: {hidden_states.min()}, max: {hidden_states.max()}")
461463

462464
start = i * self.resnets_per_upsample
463465
end = (i + 1) * self.resnets_per_upsample
@@ -467,7 +469,6 @@ def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
467469
res_sum = res_sum + self.resnets[j](hidden_states)
468470

469471
hidden_states = res_sum / self.resnets_per_upsample
470-
print(f"After resnets level {i} - shape: {hidden_states.shape}, min: {hidden_states.min()}, max: {hidden_states.max()}")
471472

472473
hidden_states = self.act_out(hidden_states)
473474
jax.debug.print("After act_out - min: {min}, max: {max}", min=hidden_states.min(), max=hidden_states.max())

0 commit comments

Comments (0)