Commit c3101a5

debug in embeddings connector
1 parent 49f2a47

2 files changed: 30 additions & 31 deletions

File tree

check_learnable_registers.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 30 additions & 10 deletions
@@ -133,9 +133,7 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
     num_duplications = t // self.num_learnable_registers
     registers = jnp.tile(self.learnable_registers[...], (num_duplications, 1))
 
-    if attention_mask.ndim == 4:
-      mask = attention_mask.squeeze(1).squeeze(1)
-    elif attention_mask.ndim == 2:
+    if attention_mask.ndim == 2:
       mask = attention_mask
     else:
       mask = attention_mask.squeeze(-1)  # [B, T]
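
Note: this hunk drops the 4-D branch (a [B, 1, 1, T] mask squeezed twice), so the connector now accepts only a 2-D [B, T] mask or a mask whose trailing singleton dimension squeezes to [B, T]. A minimal sketch of the surviving normalization, with toy shapes assumed (not part of the commit):

```python
import jax.numpy as jnp

# Toy [B, T, 1] mask; the surviving else-branch squeezes it to [B, T].
attention_mask = jnp.ones((2, 8, 1), dtype=jnp.int32)
if attention_mask.ndim == 2:
  mask = attention_mask
else:
  mask = attention_mask.squeeze(-1)  # [B, T]
assert mask.shape == (2, 8)
```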
@@ -155,15 +153,16 @@ def _replace_padded_with_learnable_registers(self, hidden_states: Array, attenti
     shifted_hidden_states = jnp.zeros_like(hidden_states)
     shifted_hidden_states = shifted_hidden_states.at[b_idx, target_indices, :].set(hidden_states)
 
+    # Shift mask
+    shifted_mask = jnp.zeros_like(curr_mask)
+    shifted_mask = shifted_mask.at[b_idx, target_indices].set(curr_mask)
+
     # 2. Add Learnable Registers
-    # Where flipped_mask is 1, keep valid tokens. Where 0, insert registers.
-    flipped_mask = jnp.flip(curr_mask, axis=[1])
-    flipped_mask_expanded = flipped_mask[..., None]
-
-    output = jnp.where(flipped_mask_expanded == 1, shifted_hidden_states, registers)
+    # Where shifted_mask is 1, keep valid tokens. Where 0, insert registers.
+    output = jnp.where(shifted_mask[..., None] == 1, shifted_hidden_states, registers)
 
-    # Overwrite attention_mask with all-zeros since padding is now filled with registers.
-    new_mask = jnp.zeros_like(attention_mask)
+    # Overwrite attention_mask with all-ones since padding is now filled with registers.
+    new_mask = jnp.ones_like(attention_mask)
     return output, new_mask
 
   def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
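
Note: the old code flipped `curr_mask` to decide where registers go, which is only equivalent to the shift when padding sits contiguously at the end of the sequence. The fix scatters the mask through the same `target_indices` used for the hidden states, so mask and tokens stay aligned for any padding layout. A self-contained toy reconstruction; the `num_pad`/`b_idx`/`target_indices` construction below is an assumption (the commit builds these elsewhere in the method):

```python
import jax.numpy as jnp

# Toy shapes: B=1, T=4, D=1. Two valid tokens, then two padded slots.
hidden_states = jnp.array([[[1.0], [2.0], [0.0], [0.0]]])
curr_mask = jnp.array([[1, 1, 0, 0]], dtype=jnp.int32)  # 1 = valid token
registers = jnp.full((1, 4, 1), 9.0)  # stand-in for the tiled learnable registers

# Assumed shift: rotate each row right by its number of padded positions,
# so valid tokens land at the tail and the freed-up head holds registers.
t = hidden_states.shape[1]
num_pad = t - curr_mask.sum(axis=1)                               # [B]
b_idx = jnp.arange(hidden_states.shape[0])[:, None]               # [B, 1]
target_indices = (jnp.arange(t)[None, :] + num_pad[:, None]) % t  # [B, T]

shifted_hidden_states = jnp.zeros_like(hidden_states).at[b_idx, target_indices, :].set(hidden_states)
shifted_mask = jnp.zeros_like(curr_mask).at[b_idx, target_indices].set(curr_mask)

output = jnp.where(shifted_mask[..., None] == 1, shifted_hidden_states, registers)
print(output[..., 0])  # [[9. 9. 1. 2.]] -- registers in front, valid tokens at the tail
```

The switch from `jnp.zeros_like` to `jnp.ones_like` follows from the corrected comment: after replacement, every position holds either a valid token or a register, so no position should be masked out.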
@@ -181,9 +180,20 @@ def __call__(
       hidden_states: Array,
       attention_mask: Optional[Array] = None,
   ) -> Tuple[Array, Array]:
+
+    # Debug print 1: Start
+    print(f"\nDEBUG: Embeddings1DConnector Start. hidden_states shape: {hidden_states.shape}")
+    _t_np = jax.device_get(hidden_states)
+    print(f"  min: {_t_np.min():.5f}, max: {_t_np.max():.5f}, mean: {_t_np.mean():.5f}, std: {_t_np.std():.5f}")
+
     # 1. Thinking Tokens
     if self.num_learnable_registers > 0 and attention_mask is not None:
       hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
+
+    # Debug print 2: After Padding Replacement
+    print(f"DEBUG: After replacing padded with registers. hidden_states shape: {hidden_states.shape}")
+    _t_np = jax.device_get(hidden_states)
+    print(f"  min: {_t_np.min():.5f}, max: {_t_np.max():.5f}, mean: {_t_np.mean():.5f}, std: {_t_np.std():.5f}")
 
     # 2. RoPE
     seq_len = hidden_states.shape[1]
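
Note: the same shape/min/max/mean/std pattern repeats at four checkpoints. A hypothetical helper (not in the commit) that factors it out; `jax.device_get` blocks until the array is materialized on the host, so these prints act as sync points and require concrete arrays:

```python
import jax

def debug_stats(tag, x):
  """Hypothetical helper factoring out the commit's repeated debug prints."""
  x_np = jax.device_get(x)  # blocks; copies the array to host memory
  print(f"DEBUG: {tag}. shape: {x_np.shape}")
  print(f"  min: {x_np.min():.5f}, max: {x_np.max():.5f}, "
        f"mean: {x_np.mean():.5f}, std: {x_np.std():.5f}")
```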
@@ -205,8 +215,18 @@ def block_scan_fn(carry, block_module):
         in_axes=(nnx.Carry, 0),  # Scan over the layers dimension (0) of block_module
         out_axes=(nnx.Carry, 0),
     )(hidden_states, self.stacked_blocks)
+
+    # Debug print 3: After scan
+    print(f"DEBUG: After transformer blocks scan. hidden_states shape: {hidden_states.shape}")
+    _t_np = jax.device_get(hidden_states)
+    print(f"  min: {_t_np.min():.5f}, max: {_t_np.max():.5f}, mean: {_t_np.mean():.5f}, std: {_t_np.std():.5f}")
 
     # 4. Final Norm
     hidden_states = self.final_norm(hidden_states)
 
+    # Debug print 4: Final Norm
+    print(f"DEBUG: After final norm. hidden_states shape: {hidden_states.shape}")
+    _t_np = jax.device_get(hidden_states)
+    print(f"  min: {_t_np.min():.5f}, max: {_t_np.max():.5f}, mean: {_t_np.mean():.5f}, std: {_t_np.std():.5f}")
+
     return hidden_states, attention_mask
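
Note: if this `__call__` is ever traced under `jax.jit`, `jax.device_get` and the f-string formatting fail on tracers (the values are not yet concrete). `jax.debug.print` is the traceable alternative; a sketch of an equivalent checkpoint (not part of the commit):

```python
import jax

def debug_stats_traced(x):
  """jit-compatible stand-in: jax.debug.print stages the print as a callback."""
  jax.debug.print("DEBUG: min={mn}, max={mx}, mean={mean}, std={std}",
                  mn=x.min(), mx=x.max(), mean=x.mean(), std=x.std())
```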
