@@ -438,19 +438,26 @@ def __init__(
438438 )
439439
440440 def __call__ (self , hidden_states : Array , time_last : bool = False ) -> Array :
441+ print (f"--- LTX2Vocoder Internal Debug ---" )
442+ print (f"Input hidden_states - shape: { hidden_states .shape } , min: { hidden_states .min ()} , max: { hidden_states .max ()} " )
443+
441444 if not time_last :
442445 hidden_states = jnp .transpose (hidden_states , (0 , 1 , 3 , 2 ))
446+ print (f"Transposed hidden_states - shape: { hidden_states .shape } " )
443447
444448 batch , channels , mel_bins , time = hidden_states .shape
445449 hidden_states = hidden_states .reshape (batch , channels * mel_bins , time )
446450 hidden_states = jnp .transpose (hidden_states , (0 , 2 , 1 ))
451+ print (f"Prepared hidden_states for conv_in - shape: { hidden_states .shape } " )
447452
448453 hidden_states = self .conv_in (hidden_states )
454+ print (f"After conv_in - shape: { hidden_states .shape } , min: { hidden_states .min ()} , max: { hidden_states .max ()} " )
449455
450456 for i in range (self .num_upsample_layers ):
451457 if self .act_fn == "leaky_relu" :
452458 hidden_states = jax .nn .leaky_relu (hidden_states , negative_slope = self .negative_slope )
453459 hidden_states = self .upsamplers [i ](hidden_states )
460+ print (f"After upsampler { i } - shape: { hidden_states .shape } , min: { hidden_states .min ()} , max: { hidden_states .max ()} " )
454461
455462 start = i * self .resnets_per_upsample
456463 end = (i + 1 ) * self .resnets_per_upsample
@@ -460,16 +467,20 @@ def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
460467 res_sum = res_sum + self .resnets [j ](hidden_states )
461468
462469 hidden_states = res_sum / self .resnets_per_upsample
470+ print (f"After resnets level { i } - shape: { hidden_states .shape } , min: { hidden_states .min ()} , max: { hidden_states .max ()} " )
463471
464472 hidden_states = self .act_out (hidden_states )
465473 hidden_states = self .conv_out (hidden_states )
474+ print (f"After conv_out - shape: { hidden_states .shape } , min: { hidden_states .min ()} , max: { hidden_states .max ()} " )
466475
467476 if self .final_act_fn == "tanh" :
468477 hidden_states = jnp .tanh (hidden_states )
469478 elif self .final_act_fn == "clamp" :
470479 hidden_states = jnp .clip (hidden_states , - 1 , 1 )
471480
472481 hidden_states = jnp .transpose (hidden_states , (0 , 2 , 1 ))
482+ print (f"Final LTX2Vocoder output - shape: { hidden_states .shape } , min: { hidden_states .min ()} , max: { hidden_states .max ()} " )
483+ print (f"-----------------------------------" )
473484 return hidden_states
474485
475486
0 commit comments