Skip to content

Commit 6f9af83

Browse files
committed
mask debug
1 parent 2b162af commit 6f9af83

1 file changed

Lines changed: 15 additions & 1 deletion

File tree

before_transformer_parity_maxdiffusion.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,22 @@ def patched_fe_call(self, hidden_states, attention_mask):
7575

7676
# Keep a handle on the unpatched method so the debug wrapper can delegate to it.
orig_replace = Embeddings1DConnector._replace_padded_with_learnable_registers


def patched_replace(self, hidden_states, attention_mask):
    """Print mask and register diagnostics, then call the original method.

    The wrapper only emits ``jax.debug.print`` output; ``hidden_states`` and
    ``attention_mask`` are forwarded untouched to ``orig_replace`` and its
    result is returned as-is.

    Args:
        self: Connector instance; ``self.learnable_registers.value`` is read
            for the register statistics.
        hidden_states: Passed through to the original method unchanged.
        attention_mask: Mask array. When it is not 2-D its last axis is
            squeezed away (assumes a trailing singleton axis, giving
            ``[B, T]`` -- TODO confirm against callers).

    Returns:
        Whatever ``orig_replace(self, hidden_states, attention_mask)`` returns.
    """
    # Collapse to two dimensions so the row/column slices below are valid.
    mask_2d = attention_mask if attention_mask.ndim == 2 else attention_mask.squeeze(-1)
    # Binarise at 0.5: entries above the threshold are counted as valid tokens.
    valid = (mask_2d > 0.5).astype(jnp.int32)

    jax.debug.print("\n[MAXDIFFUSION] Mask Debug:")
    jax.debug.print(
        " Input Attn Mask min/max: {} / {}",
        jnp.min(attention_mask),
        jnp.max(attention_mask),
    )
    jax.debug.print(" Curr Mask sum: {} (valid tokens)", jnp.sum(valid))
    jax.debug.print(" Curr Mask start 20 elements: {}", valid[0, :20])

    # Reverse along axis 1 and show the first row of the reversed mask.
    reversed_mask = jnp.flip(valid, axis=1)
    jax.debug.print(
        " Flipped Mask Fwd logic sum: {} (first 20 elements: {})",
        jnp.sum(reversed_mask),
        reversed_mask[0, :20],
    )

    registers = self.learnable_registers.value
    jax.debug.print(
        " [MAXDIFFUSION] Connector Registers std: {std}, mean: {mean}, min: {min}",
        std=jnp.std(registers),
        mean=jnp.mean(registers),
        min=jnp.min(registers),
    )

    return orig_replace(self, hidden_states, attention_mask)

0 commit comments

Comments
 (0)