Commit 0274615

qkv debug
1 parent a36a679 commit 0274615

1 file changed: before_transformer_parity_maxdiffusion.py (16 additions, 33 deletions)

@@ -81,18 +81,13 @@ def patched_replace(self, hidden_states, attention_mask):
     mask = attention_mask.squeeze(-1)  # [B, T]
     curr_mask = (mask > 0.5).astype(jnp.int32)
 
-    jax.debug.print("\n[MAXDIFFUSION] Mask Debug:")
-    jax.debug.print("  Input Attn Mask min/max: {} / {}", jnp.min(attention_mask), jnp.max(attention_mask))
-    jax.debug.print("  Curr Mask sum: {} (valid tokens)", jnp.sum(curr_mask))
-    jax.debug.print("  Curr Mask start 20 elements: {}", curr_mask[0, :20])
+    jax.debug.print("\n[MAXDIFFUSION MASK DEBUG]")
+    jax.debug.print("  Input Attn Mask sum: {}", jnp.sum(attention_mask))
+    jax.debug.print("  Curr Mask sum: {} (start elements: {})", jnp.sum(curr_mask), curr_mask[0, :10])
 
     flipped = jnp.flip(curr_mask, axis=[1])
-    jax.debug.print("  Flipped Mask Fwd logic sum: {} (first 20 elements: {})", jnp.sum(flipped), flipped[0, :20])
+    jax.debug.print("  Flipped Mask (Diffusers Logic) sum: {} (start elements: {})", jnp.sum(flipped), flipped[0, :10])
 
-    regs = self.learnable_registers.value
-    jax.debug.print("  [MAXDIFFUSION] Connector Registers std: {std}, mean: {mean}, min: {min}",
-                    std=jnp.std(regs), mean=jnp.mean(regs), min=jnp.min(regs))
-
     return orig_replace(self, hidden_states, attention_mask)
 
 Embeddings1DConnector._replace_padded_with_learnable_registers = patched_replace
@@ -101,32 +96,20 @@ def patched_replace(self, hidden_states, attention_mask):
 
 orig_block_call = _BasicTransformerBlock1D.__call__
 def patched_block_call(self, hidden_states, attention_mask=None, rotary_emb=None):
-    jax.debug.print("[MAXDIFFUSION] Block Input std: {std}, min: {min}, max: {max}",
-                    std=jnp.std(hidden_states), min=jnp.min(hidden_states), max=jnp.max(hidden_states))
+    jax.debug.print("\n[MAXDIFFUSION W] to_q std: {k}, to_q bias: {b}",
+                    k=jnp.std(self.attn1.to_q.kernel), b=jnp.std(self.attn1.to_q.bias))
+    jax.debug.print("[MAXDIFFUSION W] to_k std: {k}, to_k bias: {b}",
+                    k=jnp.std(self.attn1.to_k.kernel), b=jnp.std(self.attn1.to_k.bias))
+    jax.debug.print("[MAXDIFFUSION W] to_v std: {k}, to_v bias: {b}",
+                    k=jnp.std(self.attn1.to_v.kernel), b=jnp.std(self.attn1.to_v.bias))
+    jax.debug.print("[MAXDIFFUSION W] to_out std: {k}, to_out bias: {b}",
+                    k=jnp.std(self.attn1.to_out.kernel), b=jnp.std(self.attn1.to_out.bias))
+    jax.debug.print("[MAXDIFFUSION W] norm_q std: {k}", k=jnp.std(self.attn1.norm_q.scale))
 
-    # 1. Norm -> Attention
-    normed1 = self.norm1(hidden_states)
-    jax.debug.print("  [MAXDIFFUSION] norm1 std: {std}, min: {min}, max: {max}",
-                    std=jnp.std(normed1), min=jnp.min(normed1), max=jnp.max(normed1))
+    if attention_mask is not None:
+        jax.debug.print("[MAXDIFFUSION MASK] supplied to attention kernel sum: {}", jnp.sum(attention_mask))
 
-    attn_output = self.attn1(normed1, attention_mask=attention_mask, rotary_emb=rotary_emb)
-    jax.debug.print("  [MAXDIFFUSION] attn1 std: {std}, min: {min}, max: {max}",
-                    std=jnp.std(attn_output), min=jnp.min(attn_output), max=jnp.max(attn_output))
-
-    hidden_states = hidden_states + attn_output
-
-    # 2. Norm -> FeedForward
-    normed2 = self.norm2(hidden_states)
-    jax.debug.print("  [MAXDIFFUSION] norm2 std: {std}, min: {min}, max: {max}",
-                    std=jnp.std(normed2), min=jnp.min(normed2), max=jnp.max(normed2))
-
-    ff_output = self.ff(normed2)
-    jax.debug.print("  [MAXDIFFUSION] ff std: {std}, min: {min}, max: {max}",
-                    std=jnp.std(ff_output), min=jnp.min(ff_output), max=jnp.max(ff_output))
-
-    hidden_states = hidden_states + ff_output
-
-    return hidden_states
+    return orig_block_call(self, hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
 
 _BasicTransformerBlock1D.__call__ = patched_block_call
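
The core mechanism in both hunks is a patch-and-probe pattern: keep a handle to the original method, install a wrapper on the class that emits jax.debug.print statistics, and have the wrapper delegate to the saved original so the numerics are unchanged. A minimal standalone sketch of that pattern follows, assuming a flax.nnx-style module; TinyBlock and its proj layer are hypothetical stand-ins for _BasicTransformerBlock1D and its attention projections.

# Sketch only: TinyBlock/proj are illustrative, not from the commit.
import jax
import jax.numpy as jnp
from flax import nnx

class TinyBlock(nnx.Module):
    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        self.proj = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        return x + self.proj(x)

# Keep a handle to the original method, then install a wrapper on the class.
orig_call = TinyBlock.__call__

def patched_call(self, x):
    # jax.debug.print also fires inside jit-traced code, unlike plain print().
    jax.debug.print("[W] proj kernel std: {k}, bias std: {b}",
                    k=jnp.std(self.proj.kernel.value),
                    b=jnp.std(self.proj.bias.value))
    # Delegate to the saved original so behavior is untouched.
    return orig_call(self, x)

TinyBlock.__call__ = patched_call

block = TinyBlock(4, rngs=nnx.Rngs(0))
_ = block(jnp.ones((2, 4)))     # prints kernel/bias stats, returns normal output

TinyBlock.__call__ = orig_call  # restore once debugging is done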

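The first hunk prints the sum of curr_mask next to the sum of its flipped variant because the two implementations being compared disagree on which side the padding sits. A quick standalone check of that logic, using an illustrative mask value (not taken from the commit):

# Flipping along the time axis converts right-padding to left-padding
# while the count of valid tokens (the printed sums) stays the same.
import jax.numpy as jnp

attention_mask = jnp.array([[1.0, 1.0, 1.0, 0.0, 0.0]])[..., None]  # [B, T, 1]
mask = attention_mask.squeeze(-1)                                   # [B, T]
curr_mask = (mask > 0.5).astype(jnp.int32)
flipped = jnp.flip(curr_mask, axis=[1])

print(curr_mask[0])                            # [1 1 1 0 0] -> pad on the right
print(flipped[0])                              # [0 0 1 1 1] -> pad on the left
print(jnp.sum(curr_mask) == jnp.sum(flipped))  # True: same valid-token count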