Skip to content

Commit 46a68be

Browse files
committed
debug in embeddings connector
1 parent c3101a5 commit 46a68be

1 file changed

Lines changed: 16 additions & 12 deletions

File tree

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -182,18 +182,20 @@ def __call__(
182182
) -> Tuple[Array, Array]:
183183

184184
# Debug print 1: Start
185-
print(f"\\nDEBUG: Embeddings1DConnector Start. hidden_states shape: {hidden_states.shape}")
186-
_t_np = jax.device_get(hidden_states)
187-
print(f" min: {_t_np.min():.5f}, max: {_t_np.max():.5f}, mean: {_t_np.mean():.5f}, std: {_t_np.std():.5f}")
185+
jax.debug.print("\\nDEBUG: Embeddings1DConnector Start. hidden_states shape: {}", hidden_states.shape)
186+
jax.debug.print(" min: {min:.5f}, max: {max:.5f}, mean: {mean:.5f}, std: {std:.5f}",
187+
min=jnp.min(hidden_states), max=jnp.max(hidden_states),
188+
mean=jnp.mean(hidden_states), std=jnp.std(hidden_states))
188189

189190
# 1. Thinking Tokens
190191
if self.num_learnable_registers > 0 and attention_mask is not None:
191192
hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
192193

193194
# Debug print 2: After Padding Replacement
194-
print(f"DEBUG: After replacing padded with registers. hidden_states shape: {hidden_states.shape}")
195-
_t_np = jax.device_get(hidden_states)
196-
print(f" min: {_t_np.min():.5f}, max: {_t_np.max():.5f}, mean: {_t_np.mean():.5f}, std: {_t_np.std():.5f}")
195+
jax.debug.print("DEBUG: After replacing padded with registers.")
196+
jax.debug.print(" min: {min:.5f}, max: {max:.5f}, mean: {mean:.5f}, std: {std:.5f}",
197+
min=jnp.min(hidden_states), max=jnp.max(hidden_states),
198+
mean=jnp.mean(hidden_states), std=jnp.std(hidden_states))
197199

198200
# 2. RoPE
199201
seq_len = hidden_states.shape[1]
@@ -217,16 +219,18 @@ def block_scan_fn(carry, block_module):
217219
)(hidden_states, self.stacked_blocks)
218220

219221
# Debug print 3: After scan
220-
print(f"DEBUG: After transformer blocks scan. hidden_states shape: {hidden_states.shape}")
221-
_t_np = jax.device_get(hidden_states)
222-
print(f" min: {_t_np.min():.5f}, max: {_t_np.max():.5f}, mean: {_t_np.mean():.5f}, std: {_t_np.std():.5f}")
222+
jax.debug.print("DEBUG: After transformer blocks scan.")
223+
jax.debug.print(" min: {min:.5f}, max: {max:.5f}, mean: {mean:.5f}, std: {std:.5f}",
224+
min=jnp.min(hidden_states), max=jnp.max(hidden_states),
225+
mean=jnp.mean(hidden_states), std=jnp.std(hidden_states))
223226

224227
# 4. Final Norm
225228
hidden_states = self.final_norm(hidden_states)
226229

227230
# Debug print 4: Final Norm
228-
print(f"DEBUG: After final norm. hidden_states shape: {hidden_states.shape}")
229-
_t_np = jax.device_get(hidden_states)
230-
print(f" min: {_t_np.min():.5f}, max: {_t_np.max():.5f}, mean: {_t_np.mean():.5f}, std: {_t_np.std():.5f}")
231+
jax.debug.print("DEBUG: After final norm.")
232+
jax.debug.print(" min: {min:.5f}, max: {max:.5f}, mean: {mean:.5f}, std: {std:.5f}",
233+
min=jnp.min(hidden_states), max=jnp.max(hidden_states),
234+
mean=jnp.mean(hidden_states), std=jnp.std(hidden_states))
231235

232236
return hidden_states, attention_mask

0 commit comments

Comments
 (0)