Skip to content

Commit a1f291c

Browse files
committed
some debug added to understand key_seq_len
1 parent 0d559d5 commit a1f291c

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,8 @@ def wrap_flash_attention(query, key, value):
 
     block_kv = max(*block_kv_sizes)
     key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+    print("Key seq len")
+    print(key_seq_len)
     value, _, _ = _pad_data_for_flash(value, heads, block_kv)
 
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))

0 commit comments

Comments
 (0)