Skip to content

Commit bdcdeb1

Browse files
committed
vocoder debug
1 parent e6105e6 commit bdcdeb1

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -470,10 +470,10 @@ def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
470470
print(f"After resnets level {i} - shape: {hidden_states.shape}, min: {hidden_states.min()}, max: {hidden_states.max()}")
471471

472472
hidden_states = self.act_out(hidden_states)
473-
print(f"After act_out - shape: {hidden_states.shape}, min: {hidden_states.min()}, max: {hidden_states.max()}")
473+
jax.debug.print("After act_out - min: {min}, max: {max}", min=hidden_states.min(), max=hidden_states.max())
474474
print(f"conv_out kernel - min: {self.conv_out.kernel.value.min()}, max: {self.conv_out.kernel.value.max()}")
475475
hidden_states = self.conv_out(hidden_states)
476-
print(f"After conv_out - shape: {hidden_states.shape}, min: {hidden_states.min()}, max: {hidden_states.max()}")
476+
jax.debug.print("After conv_out - min: {min}, max: {max}", min=hidden_states.min(), max=hidden_states.max())
477477

478478
if self.final_act_fn == "tanh":
479479
hidden_states = jnp.tanh(hidden_states)

0 commit comments

Comments
 (0)