Skip to content

Commit cd25c5d

Browse files
committed
Added debug print statements for intermediate values
1 parent 740f2ee commit cd25c5d

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ def wrap_flash_attention(query, key, value):
296296
kv_padded_len = key.shape[2]
297297
print("KV padded len:", kv_padded_len)
298298
kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
299+
print("KV indices:", kv_indices)
299300
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
300301
print("KV segment ids:", kv_segment_ids)
301302
segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)

0 commit comments

Comments (0)