Skip to content

Commit 7501406

Browse files
committed
Add deep internal prints to find silence origin
1 parent 9541cc1 commit 7501406

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,19 +438,26 @@ def __init__(
438438
)
439439

440440
def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
441+
print(f"--- LTX2Vocoder Internal Debug ---")
442+
print(f"Input hidden_states - shape: {hidden_states.shape}, min: {hidden_states.min()}, max: {hidden_states.max()}")
443+
441444
if not time_last:
442445
hidden_states = jnp.transpose(hidden_states, (0, 1, 3, 2))
446+
print(f"Transposed hidden_states - shape: {hidden_states.shape}")
443447

444448
batch, channels, mel_bins, time = hidden_states.shape
445449
hidden_states = hidden_states.reshape(batch, channels * mel_bins, time)
446450
hidden_states = jnp.transpose(hidden_states, (0, 2, 1))
451+
print(f"Prepared hidden_states for conv_in - shape: {hidden_states.shape}")
447452

448453
hidden_states = self.conv_in(hidden_states)
454+
print(f"After conv_in - shape: {hidden_states.shape}, min: {hidden_states.min()}, max: {hidden_states.max()}")
449455

450456
for i in range(self.num_upsample_layers):
451457
if self.act_fn == "leaky_relu":
452458
hidden_states = jax.nn.leaky_relu(hidden_states, negative_slope=self.negative_slope)
453459
hidden_states = self.upsamplers[i](hidden_states)
460+
print(f"After upsampler {i} - shape: {hidden_states.shape}, min: {hidden_states.min()}, max: {hidden_states.max()}")
454461

455462
start = i * self.resnets_per_upsample
456463
end = (i + 1) * self.resnets_per_upsample
@@ -460,16 +467,20 @@ def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
460467
res_sum = res_sum + self.resnets[j](hidden_states)
461468

462469
hidden_states = res_sum / self.resnets_per_upsample
470+
print(f"After resnets level {i} - shape: {hidden_states.shape}, min: {hidden_states.min()}, max: {hidden_states.max()}")
463471

464472
hidden_states = self.act_out(hidden_states)
465473
hidden_states = self.conv_out(hidden_states)
474+
print(f"After conv_out - shape: {hidden_states.shape}, min: {hidden_states.min()}, max: {hidden_states.max()}")
466475

467476
if self.final_act_fn == "tanh":
468477
hidden_states = jnp.tanh(hidden_states)
469478
elif self.final_act_fn == "clamp":
470479
hidden_states = jnp.clip(hidden_states, -1, 1)
471480

472481
hidden_states = jnp.transpose(hidden_states, (0, 2, 1))
482+
print(f"Final LTX2Vocoder output - shape: {hidden_states.shape}, min: {hidden_states.min()}, max: {hidden_states.max()}")
483+
print(f"-----------------------------------")
473484
return hidden_states
474485

475486

0 commit comments

Comments
 (0)