Skip to content

Commit 70c656e

Browse files
committed
added debug for values
1 parent cd25c5d commit 70c656e

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,9 @@ def wrap_flash_attention(query, key, value):
296296
kv_padded_len = key.shape[2]
297297
print("KV padded len:", kv_padded_len)
298298
kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
299-
print("KV indices:", kv_indices)
299+
jax.debug.print("KV indices: {}", kv_indices)
300300
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
301-
print("KV segment ids:", kv_segment_ids)
301+
jax.debug.print("KV segment ids: {}", kv_segment_ids)
302302
segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
303303

304304
# make_splash_mha is wrapped around shardmap and seq and head is already

0 commit comments

Comments
 (0)