Skip to content

Commit 1c69263

Browse files
committed
remove print statements.
1 parent d5f28aa commit 1c69263

1 file changed

Lines changed: 0 additions & 8 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,6 @@ def _tpu_flash_attention(
206206
check_rep=False,
207207
)
208208
def wrap_flash_attention(query, key, value):
209-
jax.debug.print("query.shape: {x}", x=query.shape)
210-
jax.debug.print("key.shape: {x}", x=key.shape)
211-
jax.debug.print("value.shape: {x}", x=value.shape)
212209

213210
mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
214211
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
@@ -250,11 +247,6 @@ def wrap_flash_attention(query, key, value):
250247

251248
m = m_new
252249
l = l_new
253-
jax.debug.print("Loop {i}: max(m)={m_max}, max(l)={l_max}, max(o)={o_max}",
254-
i=i,
255-
m_max=m.max(),
256-
l_max=l.max(),
257-
o_max=o.max())
258250

259251
attention_output = o / l[..., None]
260252

0 commit comments

Comments (0)